feat: Remove combined search (#1990)
- Remove use_combined_context parameter from search functions
- Remove CombinedSearchResult class from types module
- Update API routers to remove combined search support
- Remove prepare_combined_context helper function
- Update tutorial notebook to remove use_combined_context usage
- Simplify search return types to always return List[SearchResult]

This removes the combined search feature, which aggregated results across multiple datasets into a single response. Users can still search across multiple datasets and get results per dataset.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in this PR.
DO NOT use AI-generated descriptions. We want to understand your thought process and reasoning.
-->

## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to test it locally;
* Proof that it's sufficiently tested.
-->

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the issue/feature**
- [ ] My code follows the project's coding standards and style guidelines
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Breaking Changes**
  * Search API response simplified: combined-context result type removed and the legacy combined-context request flag eliminated, changing response shapes.
* **New Features**
  * dataset_name added to each search result for clearer attribution.
* **Refactor**
  * Search logic and return shapes streamlined for access-control and per-dataset flows; telemetry and request parameters aligned.
* **Tests**
  * Combined-context related tests removed or updated to reflect simplified behavior.

<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
commit
f09f66e90d
8 changed files with 61 additions and 317 deletions
|
|
@ -6,7 +6,7 @@ from fastapi import Depends, APIRouter
|
|||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult
|
||||
from cognee.modules.search.types import SearchType, SearchResult
|
||||
from cognee.api.DTO import InDTO, OutDTO
|
||||
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -31,7 +31,7 @@ class SearchPayloadDTO(InDTO):
|
|||
node_name: Optional[list[str]] = Field(default=None, example=[])
|
||||
top_k: Optional[int] = Field(default=10)
|
||||
only_context: bool = Field(default=False)
|
||||
use_combined_context: bool = Field(default=False)
|
||||
verbose: bool = Field(default=False)
|
||||
|
||||
|
||||
def get_search_router() -> APIRouter:
|
||||
|
|
@ -74,7 +74,7 @@ def get_search_router() -> APIRouter:
|
|||
except Exception as error:
|
||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||
|
||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
|
||||
@router.post("", response_model=Union[List[SearchResult], List])
|
||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Search for nodes in the graph database.
|
||||
|
|
@ -118,7 +118,7 @@ def get_search_router() -> APIRouter:
|
|||
"node_name": payload.node_name,
|
||||
"top_k": payload.top_k,
|
||||
"only_context": payload.only_context,
|
||||
"use_combined_context": payload.use_combined_context,
|
||||
"verbose": payload.verbose,
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
|
@ -135,8 +135,8 @@ def get_search_router() -> APIRouter:
|
|||
system_prompt=payload.system_prompt,
|
||||
node_name=payload.node_name,
|
||||
top_k=payload.top_k,
|
||||
verbose=payload.verbose,
|
||||
only_context=payload.only_context,
|
||||
use_combined_context=payload.use_combined_context,
|
||||
)
|
||||
|
||||
return jsonable_encoder(results)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Union, Optional, List, Type
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
|
||||
from cognee.modules.search.types import SearchResult, SearchType
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.search.methods import search as search_function
|
||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||
|
|
@ -32,12 +32,11 @@ async def search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = 1,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
verbose: bool = False,
|
||||
) -> Union[List[SearchResult], CombinedSearchResult]:
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
||||
|
|
@ -217,7 +216,6 @@ async def search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
|
|
|
|||
|
|
@ -27,6 +27,5 @@ await cognee.cognify(datasets=["python-development-with-cognee"], temporal_cogni
|
|||
results = await cognee.search(
|
||||
"What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?",
|
||||
datasets=["python-development-with-cognee"],
|
||||
use_combined_context=True, # Used to show reasoning graph visualization
|
||||
)
|
||||
print(results)
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ from cognee.modules.engine.models.node_set import NodeSet
|
|||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.search.types import (
|
||||
SearchResult,
|
||||
CombinedSearchResult,
|
||||
SearchResultDataset,
|
||||
SearchType,
|
||||
)
|
||||
from cognee.modules.search.operations import log_query, log_result
|
||||
|
|
@ -45,12 +43,11 @@ async def search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
verbose: bool = False,
|
||||
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
||||
verbose=False,
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
|
|
@ -91,7 +88,6 @@ async def search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
|
|
@ -128,93 +124,63 @@ async def search(
|
|||
query.id,
|
||||
json.dumps(
|
||||
jsonable_encoder(
|
||||
await prepare_search_result(
|
||||
search_results[0] if isinstance(search_results, list) else search_results
|
||||
)
|
||||
if use_combined_context
|
||||
else [
|
||||
await prepare_search_result(search_result) for search_result in search_results
|
||||
]
|
||||
[await prepare_search_result(search_result) for search_result in search_results]
|
||||
)
|
||||
),
|
||||
user.id,
|
||||
)
|
||||
|
||||
if use_combined_context:
|
||||
# Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info
|
||||
prepared_search_results = await prepare_search_result(
|
||||
search_results[0] if isinstance(search_results, list) else search_results
|
||||
)
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
# This is for maintaining backwards compatibility
|
||||
if backend_access_control_enabled():
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
||||
return CombinedSearchResult(
|
||||
result=result,
|
||||
graphs=graphs,
|
||||
context=context,
|
||||
datasets=[
|
||||
SearchResultDataset(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
)
|
||||
for dataset in datasets
|
||||
],
|
||||
)
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
if only_context:
|
||||
search_result_dict = {
|
||||
"search_result": [context] if context else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
}
|
||||
if verbose:
|
||||
# Include graphs only in verbose mode
|
||||
search_result_dict["graphs"] = graphs
|
||||
|
||||
return_value.append(search_result_dict)
|
||||
else:
|
||||
search_result_dict = {
|
||||
"search_result": [result] if result else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
}
|
||||
if verbose:
|
||||
# Include graphs only in verbose mode
|
||||
search_result_dict["graphs"] = graphs
|
||||
return_value.append(search_result_dict)
|
||||
|
||||
return return_value
|
||||
else:
|
||||
# This is for maintaining backwards compatibility
|
||||
if backend_access_control_enabled():
|
||||
return_value = []
|
||||
return_value = []
|
||||
if only_context:
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
if only_context:
|
||||
search_result_dict = {
|
||||
"search_result": [context] if context else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
}
|
||||
if verbose:
|
||||
# Include graphs only in verbose mode
|
||||
search_result_dict["graphs"] = graphs
|
||||
|
||||
return_value.append(search_result_dict)
|
||||
else:
|
||||
search_result_dict = {
|
||||
"search_result": [result] if result else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
}
|
||||
if verbose:
|
||||
# Include graphs only in verbose mode
|
||||
search_result_dict["graphs"] = graphs
|
||||
|
||||
return_value.append(search_result_dict)
|
||||
|
||||
return return_value
|
||||
return_value.append(prepared_search_results["context"])
|
||||
else:
|
||||
return_value = []
|
||||
if only_context:
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
return_value.append(prepared_search_results["context"])
|
||||
else:
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(result)
|
||||
# For maintaining backwards compatibility
|
||||
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||
return return_value[0]
|
||||
else:
|
||||
return return_value
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(result)
|
||||
# For maintaining backwards compatibility
|
||||
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||
return return_value[0]
|
||||
else:
|
||||
return return_value
|
||||
|
||||
|
||||
async def authorized_search(
|
||||
|
|
@ -230,14 +196,10 @@ async def authorized_search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[
|
||||
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
||||
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
||||
]:
|
||||
) -> List[Tuple[Any, Union[List[Edge], str], List[Dataset]]]:
|
||||
"""
|
||||
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
||||
Not to be used outside of active access control mode.
|
||||
|
|
@ -247,70 +209,6 @@ async def authorized_search(
|
|||
datasets=dataset_ids, permission_type="read", user=user
|
||||
)
|
||||
|
||||
if use_combined_context:
|
||||
search_responses = await search_in_datasets_context(
|
||||
search_datasets=search_datasets,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=True,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
context = {}
|
||||
datasets: List[Dataset] = []
|
||||
|
||||
for _, search_context, search_datasets in search_responses:
|
||||
for dataset in search_datasets:
|
||||
context[str(dataset.id)] = search_context
|
||||
|
||||
datasets.extend(search_datasets)
|
||||
|
||||
specific_search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
search_tools = specific_search_tools
|
||||
if len(search_tools) == 2:
|
||||
[get_completion, _] = search_tools
|
||||
else:
|
||||
get_completion = search_tools[0]
|
||||
|
||||
def prepare_combined_context(
|
||||
context,
|
||||
) -> Union[List[Edge], str]:
|
||||
combined_context = []
|
||||
|
||||
for dataset_context in context.values():
|
||||
combined_context += dataset_context
|
||||
|
||||
if combined_context and isinstance(combined_context[0], str):
|
||||
return "\n".join(combined_context)
|
||||
|
||||
return combined_context
|
||||
|
||||
combined_context = prepare_combined_context(context)
|
||||
completion = await get_completion(query_text, combined_context, session_id=session_id)
|
||||
|
||||
return completion, combined_context, datasets
|
||||
|
||||
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
||||
search_results = await search_in_datasets_context(
|
||||
search_datasets=search_datasets,
|
||||
|
|
@ -326,6 +224,7 @@ async def authorized_search(
|
|||
only_context=only_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
return search_results
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class SearchResultDataset(BaseModel):
|
||||
|
|
@ -8,13 +8,6 @@ class SearchResultDataset(BaseModel):
|
|||
name: str
|
||||
|
||||
|
||||
class CombinedSearchResult(BaseModel):
|
||||
result: Optional[Any]
|
||||
context: Dict[str, Any]
|
||||
graphs: Optional[Dict[str, Any]] = {}
|
||||
datasets: Optional[List[SearchResultDataset]] = None
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
search_result: Any
|
||||
dataset_id: Optional[UUID]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from .SearchType import SearchType
|
||||
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult
|
||||
from .SearchResult import SearchResult, SearchResultDataset
|
||||
|
|
|
|||
|
|
@ -199,35 +199,7 @@ async def test_search_access_control_only_context_returns_dataset_shaped_dicts(
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_access_control_use_combined_context_returns_combined_model(
|
||||
monkeypatch, search_mod
|
||||
):
|
||||
user = _make_user()
|
||||
ds1 = _make_dataset(name="ds1", tenant_id="t1")
|
||||
ds2 = _make_dataset(name="ds2", tenant_id="t1")
|
||||
|
||||
async def dummy_authorized_search(**_kwargs):
|
||||
return ("answer", {"k": "v"}, [ds1, ds2])
|
||||
|
||||
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
|
||||
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
|
||||
|
||||
out = await search_mod.search(
|
||||
query_text="q",
|
||||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=[ds1.id, ds2.id],
|
||||
user=user,
|
||||
use_combined_context=True,
|
||||
)
|
||||
|
||||
assert out.result == "answer"
|
||||
assert out.context == {"k": "v"}
|
||||
assert out.graphs == {}
|
||||
assert [d.id for d in out.datasets] == [ds1.id, ds2.id]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod):
|
||||
async def test_authorized_search_delegates_to_search_in_datasets_context(monkeypatch, search_mod):
|
||||
user = _make_user()
|
||||
ds = _make_dataset(name="ds1")
|
||||
|
||||
|
|
@ -237,7 +209,6 @@ async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod)
|
|||
expected = [("r", ["ctx"], [ds])]
|
||||
|
||||
async def dummy_search_in_datasets_context(**kwargs):
|
||||
assert kwargs["use_combined_context"] is False if "use_combined_context" in kwargs else True
|
||||
return expected
|
||||
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -250,104 +221,12 @@ async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod)
|
|||
query_text="q",
|
||||
user=user,
|
||||
dataset_ids=[ds.id],
|
||||
use_combined_context=False,
|
||||
only_context=False,
|
||||
)
|
||||
|
||||
assert out == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorized_search_use_combined_context_joins_string_context(monkeypatch, search_mod):
|
||||
user = _make_user()
|
||||
ds1 = _make_dataset(name="ds1")
|
||||
ds2 = _make_dataset(name="ds2")
|
||||
|
||||
async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
|
||||
return [ds1, ds2]
|
||||
|
||||
async def dummy_search_in_datasets_context(**kwargs):
|
||||
assert kwargs["only_context"] is True
|
||||
return [(None, ["a"], [ds1]), (None, ["b"], [ds2])]
|
||||
|
||||
seen = {}
|
||||
|
||||
async def dummy_get_completion(query_text, context, session_id=None):
|
||||
seen["query_text"] = query_text
|
||||
seen["context"] = context
|
||||
seen["session_id"] = session_id
|
||||
return ["answer"]
|
||||
|
||||
async def dummy_get_search_type_tools(**_kwargs):
|
||||
return [dummy_get_completion, lambda *_a, **_k: None]
|
||||
|
||||
monkeypatch.setattr(
|
||||
search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
|
||||
)
|
||||
monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
|
||||
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
|
||||
|
||||
completion, combined_context, datasets = await search_mod.authorized_search(
|
||||
query_type=SearchType.CHUNKS,
|
||||
query_text="q",
|
||||
user=user,
|
||||
dataset_ids=[ds1.id, ds2.id],
|
||||
use_combined_context=True,
|
||||
session_id="s1",
|
||||
)
|
||||
|
||||
assert combined_context == "a\nb"
|
||||
assert completion == ["answer"]
|
||||
assert datasets == [ds1, ds2]
|
||||
assert seen == {"query_text": "q", "context": "a\nb", "session_id": "s1"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorized_search_use_combined_context_keeps_non_string_context(
|
||||
monkeypatch, search_mod
|
||||
):
|
||||
user = _make_user()
|
||||
ds1 = _make_dataset(name="ds1")
|
||||
ds2 = _make_dataset(name="ds2")
|
||||
|
||||
class DummyEdge:
|
||||
pass
|
||||
|
||||
e1, e2 = DummyEdge(), DummyEdge()
|
||||
|
||||
async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
|
||||
return [ds1, ds2]
|
||||
|
||||
async def dummy_search_in_datasets_context(**_kwargs):
|
||||
return [(None, [e1], [ds1]), (None, [e2], [ds2])]
|
||||
|
||||
async def dummy_get_completion(query_text, context, session_id=None):
|
||||
assert query_text == "q"
|
||||
assert context == [e1, e2]
|
||||
return ["answer"]
|
||||
|
||||
async def dummy_get_search_type_tools(**_kwargs):
|
||||
return [dummy_get_completion]
|
||||
|
||||
monkeypatch.setattr(
|
||||
search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
|
||||
)
|
||||
monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
|
||||
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
|
||||
|
||||
completion, combined_context, datasets = await search_mod.authorized_search(
|
||||
query_type=SearchType.CHUNKS,
|
||||
query_text="q",
|
||||
user=user,
|
||||
dataset_ids=[ds1.id, ds2.id],
|
||||
use_combined_context=True,
|
||||
)
|
||||
|
||||
assert combined_context == [e1, e2]
|
||||
assert completion == ["answer"]
|
||||
assert datasets == [ds1, ds2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_in_datasets_context_two_tool_context_override_and_is_empty_branches(
|
||||
monkeypatch, search_mod
|
||||
|
|
|
|||
|
|
@ -183,30 +183,6 @@ async def test_search_access_control_results_edges_become_graph_result(monkeypat
|
|||
assert "edges" in out[0]["search_result"][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_use_combined_context_defaults_empty_datasets(monkeypatch, search_mod):
|
||||
user = types.SimpleNamespace(id="u1", tenant_id=None)
|
||||
|
||||
async def dummy_authorized_search(**_kwargs):
|
||||
return ("answer", "ctx", [])
|
||||
|
||||
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
|
||||
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
|
||||
|
||||
out = await search_mod.search(
|
||||
query_text="q",
|
||||
query_type=SearchType.CHUNKS,
|
||||
dataset_ids=None,
|
||||
user=user,
|
||||
use_combined_context=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
assert out.result == "answer"
|
||||
assert out.context == {"all available datasets": "ctx"}
|
||||
assert out.datasets[0].name == "all available datasets"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_access_control_context_str_branch(monkeypatch, search_mod):
|
||||
"""Covers prepare_search_result(context is str) through search()."""
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue