chore: covering higher level search logic with tests (#1910)
<!-- .github/pull_request_template.md -->
## Description
This PR covers the higher-level search.py logic with unit tests. As part of this work, the tests fully cover the following core logic:
- search.py
- get_search_type_tools (including all of the core search types)
- the search → prepare_search_result contract (behavior exercised through the search.py interface)
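
For orientation, below is a minimal sketch of the contract these tests pin down. It assumes the module paths used in the new test files and is not part of this PR's diff:

```python
# Minimal sketch (assumed paths, not part of the diff): for most search types,
# get_search_type_tools is expected to return two bound methods on one retriever.
import asyncio

from cognee.modules.search.types import SearchType
from cognee.modules.search.methods.get_search_type_tools import get_search_type_tools


async def main():
    # The tests assert the returned callables are a completion tool and a
    # context tool bound to the same retriever class (here ChunksRetriever).
    tools = await get_search_type_tools(SearchType.CHUNKS, query_text="hello")
    completion, context = tools
    print(completion.__name__, context.__name__)  # expected: get_completion get_context


asyncio.run(main())
```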
## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to
test it locally;
* Proof that it's sufficiently tested.
-->
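The new tests are plain pytest unit tests. Assuming a local dev install with `pytest` and `pytest-asyncio` (the tests use `@pytest.mark.asyncio`), they can be run with `pytest cognee/tests/unit/modules/search/`.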
## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):
## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->
## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages
## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Tests**
* Added comprehensive unit test coverage for search functionality,
including search type tool selection, search operations, and result
preparation workflows across multiple scenarios and edge cases.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Commit 9b2b1a9c13: 3 changed files with 980 additions and 0 deletions.
### cognee/tests/unit/modules/search/test_get_search_type_tools.py (new file, 220 additions)

```python
import pytest

from cognee.modules.search.exceptions import UnsupportedSearchTypeError
from cognee.modules.search.types import SearchType


class _DummyCommunityRetriever:
    def __init__(self, *args, **kwargs):
        self.kwargs = kwargs

    def get_completion(self, *args, **kwargs):
        return {"kind": "completion", "init": self.kwargs, "args": args, "kwargs": kwargs}

    def get_context(self, *args, **kwargs):
        return {"kind": "context", "init": self.kwargs, "args": args, "kwargs": kwargs}


@pytest.mark.asyncio
async def test_feeling_lucky_delegates_to_select_search_type(monkeypatch):
    import cognee.modules.search.methods.get_search_type_tools as mod
    from cognee.modules.retrieval.chunks_retriever import ChunksRetriever

    async def _fake_select_search_type(query_text: str):
        assert query_text == "hello"
        return SearchType.CHUNKS

    monkeypatch.setattr(mod, "select_search_type", _fake_select_search_type)

    tools = await mod.get_search_type_tools(SearchType.FEELING_LUCKY, query_text="hello")

    assert len(tools) == 2
    assert all(callable(t) for t in tools)
    assert tools[0].__name__ == "get_completion"
    assert tools[1].__name__ == "get_context"
    assert tools[0].__self__.__class__ is ChunksRetriever
    assert tools[1].__self__.__class__ is ChunksRetriever


@pytest.mark.asyncio
async def test_disallowed_cypher_search_types_raise(monkeypatch):
    import cognee.modules.search.methods.get_search_type_tools as mod

    monkeypatch.setenv("ALLOW_CYPHER_QUERY", "false")

    with pytest.raises(UnsupportedSearchTypeError, match="disabled"):
        await mod.get_search_type_tools(SearchType.CYPHER, query_text="MATCH (n) RETURN n")

    with pytest.raises(UnsupportedSearchTypeError, match="disabled"):
        await mod.get_search_type_tools(SearchType.NATURAL_LANGUAGE, query_text="Find nodes")


@pytest.mark.asyncio
async def test_allowed_cypher_search_types_return_tools(monkeypatch):
    import cognee.modules.search.methods.get_search_type_tools as mod
    from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever

    monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true")

    tools = await mod.get_search_type_tools(SearchType.CYPHER, query_text="q")
    assert len(tools) == 2
    assert tools[0].__name__ == "get_completion"
    assert tools[1].__name__ == "get_context"
    assert tools[0].__self__.__class__ is CypherSearchRetriever
    assert tools[1].__self__.__class__ is CypherSearchRetriever


@pytest.mark.asyncio
async def test_registered_community_retriever_is_used(monkeypatch):
    """
    Integration point: community retrievers are loaded from the registry module and should
    override the default mapping when present.
    """
    import cognee.modules.search.methods.get_search_type_tools as mod
    from cognee.modules.retrieval import registered_community_retrievers as registry

    monkeypatch.setattr(
        registry,
        "registered_community_retrievers",
        {SearchType.SUMMARIES: _DummyCommunityRetriever},
    )

    tools = await mod.get_search_type_tools(SearchType.SUMMARIES, query_text="q", top_k=7)

    assert len(tools) == 2
    assert tools[0].__self__.__class__ is _DummyCommunityRetriever
    assert tools[0].__self__.kwargs["top_k"] == 7
    assert tools[1].__self__.__class__ is _DummyCommunityRetriever
    assert tools[1].__self__.kwargs["top_k"] == 7


@pytest.mark.asyncio
async def test_unknown_query_type_raises_unsupported():
    import cognee.modules.search.methods.get_search_type_tools as mod

    with pytest.raises(UnsupportedSearchTypeError, match="UNKNOWN_TYPE"):
        await mod.get_search_type_tools("UNKNOWN_TYPE", query_text="q")


@pytest.mark.asyncio
async def test_default_mapping_passes_top_k_to_retrievers():
    import cognee.modules.search.methods.get_search_type_tools as mod
    from cognee.modules.retrieval.summaries_retriever import SummariesRetriever

    tools = await mod.get_search_type_tools(SearchType.SUMMARIES, query_text="q", top_k=4)
    assert len(tools) == 2
    assert tools[0].__self__.__class__ is SummariesRetriever
    assert tools[1].__self__.__class__ is SummariesRetriever
    assert tools[0].__self__.top_k == 4
    assert tools[1].__self__.top_k == 4


@pytest.mark.asyncio
async def test_chunks_lexical_returns_jaccard_tools():
    import cognee.modules.search.methods.get_search_type_tools as mod
    from cognee.modules.retrieval.jaccard_retrival import JaccardChunksRetriever

    tools = await mod.get_search_type_tools(SearchType.CHUNKS_LEXICAL, query_text="q", top_k=3)
    assert len(tools) == 2
    assert tools[0].__self__.__class__ is JaccardChunksRetriever
    assert tools[1].__self__.__class__ is JaccardChunksRetriever
    assert tools[0].__self__ is tools[1].__self__


@pytest.mark.asyncio
async def test_coding_rules_uses_node_name_as_rules_nodeset_name():
    import cognee.modules.search.methods.get_search_type_tools as mod
    from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever

    tools = await mod.get_search_type_tools(SearchType.CODING_RULES, query_text="q", node_name=[])
    assert len(tools) == 1
    assert tools[0].__name__ == "get_existing_rules"
    assert tools[0].__self__.__class__ is CodingRulesRetriever

    assert tools[0].__self__.rules_nodeset_name == ["coding_agent_rules"]


@pytest.mark.asyncio
async def test_feedback_uses_last_k():
    import cognee.modules.search.methods.get_search_type_tools as mod
    from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback

    tools = await mod.get_search_type_tools(SearchType.FEEDBACK, query_text="q", last_k=11)
    assert len(tools) == 1
    assert tools[0].__name__ == "add_feedback"
    assert tools[0].__self__.__class__ is UserQAFeedback
    assert tools[0].__self__.last_k == 11


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "query_type, expected_class_name, expected_method_names",
    [
        (SearchType.CHUNKS, "ChunksRetriever", ("get_completion", "get_context")),
        (SearchType.RAG_COMPLETION, "CompletionRetriever", ("get_completion", "get_context")),
        (SearchType.TRIPLET_COMPLETION, "TripletRetriever", ("get_completion", "get_context")),
        (
            SearchType.GRAPH_COMPLETION,
            "GraphCompletionRetriever",
            ("get_completion", "get_context"),
        ),
        (
            SearchType.GRAPH_COMPLETION_COT,
            "GraphCompletionCotRetriever",
            ("get_completion", "get_context"),
        ),
        (
            SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
            "GraphCompletionContextExtensionRetriever",
            ("get_completion", "get_context"),
        ),
        (
            SearchType.GRAPH_SUMMARY_COMPLETION,
            "GraphSummaryCompletionRetriever",
            ("get_completion", "get_context"),
        ),
        (SearchType.TEMPORAL, "TemporalRetriever", ("get_completion", "get_context")),
        (
            SearchType.NATURAL_LANGUAGE,
            "NaturalLanguageRetriever",
            ("get_completion", "get_context"),
        ),
    ],
)
async def test_tool_construction_for_supported_search_types(
    monkeypatch, query_type, expected_class_name, expected_method_names
):
    import cognee.modules.search.methods.get_search_type_tools as mod

    monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true")

    tools = await mod.get_search_type_tools(query_type, query_text="q")

    assert len(tools) == 2
    assert tools[0].__name__ == expected_method_names[0]
    assert tools[1].__name__ == expected_method_names[1]
    assert tools[0].__self__.__class__.__name__ == expected_class_name
    assert tools[1].__self__.__class__.__name__ == expected_class_name


@pytest.mark.asyncio
async def test_some_completion_tools_are_callable_without_backends(monkeypatch):
    """
    "Making search tools" should include that the returned callables are usable.
    For retrievers that accept an explicit `context`, we can call get_completion without touching
    DB/LLM backends.
    """
    import cognee.modules.search.methods.get_search_type_tools as mod

    monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true")

    for query_type in [
        SearchType.CHUNKS,
        SearchType.SUMMARIES,
        SearchType.CYPHER,
        SearchType.NATURAL_LANGUAGE,
    ]:
        tools = await mod.get_search_type_tools(query_type, query_text="q")
        completion = tools[0]
        result = await completion("q", context=["ok"])
        assert result == ["ok"]
```
### cognee/tests/unit/modules/search/test_search.py (new file, 464 additions)

```python
import types
from uuid import uuid4

import pytest

from cognee.modules.search.types import SearchType


def _make_user(user_id: str = "u1", tenant_id=None):
    return types.SimpleNamespace(id=user_id, tenant_id=tenant_id)


def _make_dataset(*, name="ds", tenant_id="t1", dataset_id=None, owner_id=None):
    return types.SimpleNamespace(
        id=dataset_id or uuid4(),
        name=name,
        tenant_id=tenant_id,
        owner_id=owner_id or uuid4(),
    )


@pytest.fixture
def search_mod():
    import importlib

    return importlib.import_module("cognee.modules.search.methods.search")


@pytest.fixture(autouse=True)
def _patch_side_effect_boundaries(monkeypatch, search_mod):
    """
    Keep production logic; patch only unavoidable side-effect boundaries.
    """

    async def dummy_log_query(_query_text, _query_type, _user_id):
        return types.SimpleNamespace(id="qid-1")

    async def dummy_log_result(*_args, **_kwargs):
        return None

    async def dummy_prepare_search_result(search_result):
        if isinstance(search_result, tuple) and len(search_result) == 3:
            result, context, datasets = search_result
            return {"result": result, "context": context, "graphs": {}, "datasets": datasets}
        return {"result": None, "context": None, "graphs": {}, "datasets": []}

    monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None)
    monkeypatch.setattr(search_mod, "log_query", dummy_log_query)
    monkeypatch.setattr(search_mod, "log_result", dummy_log_result)
    monkeypatch.setattr(search_mod, "prepare_search_result", dummy_prepare_search_result)

    yield


@pytest.mark.asyncio
async def test_search_no_access_control_flattens_single_list_result(monkeypatch, search_mod):
    user = _make_user()

    async def dummy_no_access_control_search(**_kwargs):
        return (["r"], ["ctx"], [])

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False)
    monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=None,
        user=user,
    )

    assert out == ["r"]


@pytest.mark.asyncio
async def test_search_no_access_control_non_list_result_returns_list(monkeypatch, search_mod):
    """
    Covers the non-flattening back-compat branch in `search()`: if the single returned result is
    not a list, `search()` returns a list of results instead of flattening.
    """
    user = _make_user()

    async def dummy_no_access_control_search(**_kwargs):
        return ("r", ["ctx"], [])

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False)
    monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=None,
        user=user,
    )

    assert out == ["r"]


@pytest.mark.asyncio
async def test_search_no_access_control_only_context_returns_context(monkeypatch, search_mod):
    user = _make_user()

    async def dummy_no_access_control_search(**_kwargs):
        return (None, ["ctx"], [])

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False)
    monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=None,
        user=user,
        only_context=True,
    )

    assert out == ["ctx"]


@pytest.mark.asyncio
async def test_search_access_control_returns_dataset_shaped_dicts(monkeypatch, search_mod):
    user = _make_user()
    ds = _make_dataset(name="ds1", tenant_id="t1")

    async def dummy_authorized_search(**kwargs):
        assert kwargs["dataset_ids"] == [ds.id]
        return [("r", ["ctx"], [ds])]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds.id],
        user=user,
    )

    assert out == [
        {
            "search_result": ["r"],
            "dataset_id": ds.id,
            "dataset_name": "ds1",
            "dataset_tenant_id": "t1",
            "graphs": {},
        }
    ]


@pytest.mark.asyncio
async def test_search_access_control_only_context_returns_dataset_shaped_dicts(
    monkeypatch, search_mod
):
    user = _make_user()
    ds = _make_dataset(name="ds1", tenant_id="t1")

    async def dummy_authorized_search(**_kwargs):
        return [(None, "ctx", [ds])]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds.id],
        user=user,
        only_context=True,
    )

    assert out == [
        {
            "search_result": ["ctx"],
            "dataset_id": ds.id,
            "dataset_name": "ds1",
            "dataset_tenant_id": "t1",
            "graphs": {},
        }
    ]


@pytest.mark.asyncio
async def test_search_access_control_use_combined_context_returns_combined_model(
    monkeypatch, search_mod
):
    user = _make_user()
    ds1 = _make_dataset(name="ds1", tenant_id="t1")
    ds2 = _make_dataset(name="ds2", tenant_id="t1")

    async def dummy_authorized_search(**_kwargs):
        return ("answer", {"k": "v"}, [ds1, ds2])

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds1.id, ds2.id],
        user=user,
        use_combined_context=True,
    )

    assert out.result == "answer"
    assert out.context == {"k": "v"}
    assert out.graphs == {}
    assert [d.id for d in out.datasets] == [ds1.id, ds2.id]


@pytest.mark.asyncio
async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod):
    user = _make_user()
    ds = _make_dataset(name="ds1")

    async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
        return [ds]

    expected = [("r", ["ctx"], [ds])]

    async def dummy_search_in_datasets_context(**kwargs):
        assert kwargs["use_combined_context"] is False if "use_combined_context" in kwargs else True
        return expected

    monkeypatch.setattr(
        search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
    )
    monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)

    out = await search_mod.authorized_search(
        query_type=SearchType.CHUNKS,
        query_text="q",
        user=user,
        dataset_ids=[ds.id],
        use_combined_context=False,
        only_context=False,
    )

    assert out == expected


@pytest.mark.asyncio
async def test_authorized_search_use_combined_context_joins_string_context(monkeypatch, search_mod):
    user = _make_user()
    ds1 = _make_dataset(name="ds1")
    ds2 = _make_dataset(name="ds2")

    async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
        return [ds1, ds2]

    async def dummy_search_in_datasets_context(**kwargs):
        assert kwargs["only_context"] is True
        return [(None, ["a"], [ds1]), (None, ["b"], [ds2])]

    seen = {}

    async def dummy_get_completion(query_text, context, session_id=None):
        seen["query_text"] = query_text
        seen["context"] = context
        seen["session_id"] = session_id
        return ["answer"]

    async def dummy_get_search_type_tools(**_kwargs):
        return [dummy_get_completion, lambda *_a, **_k: None]

    monkeypatch.setattr(
        search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
    )
    monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
    monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)

    completion, combined_context, datasets = await search_mod.authorized_search(
        query_type=SearchType.CHUNKS,
        query_text="q",
        user=user,
        dataset_ids=[ds1.id, ds2.id],
        use_combined_context=True,
        session_id="s1",
    )

    assert combined_context == "a\nb"
    assert completion == ["answer"]
    assert datasets == [ds1, ds2]
    assert seen == {"query_text": "q", "context": "a\nb", "session_id": "s1"}


@pytest.mark.asyncio
async def test_authorized_search_use_combined_context_keeps_non_string_context(
    monkeypatch, search_mod
):
    user = _make_user()
    ds1 = _make_dataset(name="ds1")
    ds2 = _make_dataset(name="ds2")

    class DummyEdge:
        pass

    e1, e2 = DummyEdge(), DummyEdge()

    async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
        return [ds1, ds2]

    async def dummy_search_in_datasets_context(**_kwargs):
        return [(None, [e1], [ds1]), (None, [e2], [ds2])]

    async def dummy_get_completion(query_text, context, session_id=None):
        assert query_text == "q"
        assert context == [e1, e2]
        return ["answer"]

    async def dummy_get_search_type_tools(**_kwargs):
        return [dummy_get_completion]

    monkeypatch.setattr(
        search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
    )
    monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
    monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)

    completion, combined_context, datasets = await search_mod.authorized_search(
        query_type=SearchType.CHUNKS,
        query_text="q",
        user=user,
        dataset_ids=[ds1.id, ds2.id],
        use_combined_context=True,
    )

    assert combined_context == [e1, e2]
    assert completion == ["answer"]
    assert datasets == [ds1, ds2]


@pytest.mark.asyncio
async def test_search_in_datasets_context_two_tool_context_override_and_is_empty_branches(
    monkeypatch, search_mod
):
    ds1 = _make_dataset(name="ds1")
    ds2 = _make_dataset(name="ds2")

    async def dummy_set_database_global_context_variables(*_args, **_kwargs):
        return None

    class DummyGraphEngine:
        async def is_empty(self):
            return True

    async def dummy_get_graph_engine():
        return DummyGraphEngine()

    async def dummy_get_dataset_data(dataset_id):
        return [1] if dataset_id == ds1.id else []

    calls = {"completion": 0, "context": 0}

    async def dummy_get_context(_query_text: str):
        calls["context"] += 1
        return ["ctx"]

    async def dummy_get_completion(_query_text: str, _context, session_id=None):
        calls["completion"] += 1
        assert session_id == "s1"
        return ["r"]

    async def dummy_get_search_type_tools(**_kwargs):
        return [dummy_get_completion, dummy_get_context]

    monkeypatch.setattr(
        search_mod,
        "set_database_global_context_variables",
        dummy_set_database_global_context_variables,
    )
    monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine)
    monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
    monkeypatch.setattr("cognee.modules.data.methods.get_dataset_data", dummy_get_dataset_data)

    out = await search_mod.search_in_datasets_context(
        search_datasets=[ds1, ds2],
        query_type=SearchType.CHUNKS,
        query_text="q",
        context=["pre_ctx"],
        session_id="s1",
    )

    assert out == [(["r"], ["pre_ctx"], [ds1]), (["r"], ["pre_ctx"], [ds2])]
    assert calls == {"completion": 2, "context": 0}


@pytest.mark.asyncio
async def test_search_in_datasets_context_two_tool_only_context_true(monkeypatch, search_mod):
    ds = _make_dataset(name="ds1")

    async def dummy_set_database_global_context_variables(*_args, **_kwargs):
        return None

    class DummyGraphEngine:
        async def is_empty(self):
            return False

    async def dummy_get_graph_engine():
        return DummyGraphEngine()

    async def dummy_get_context(query_text: str):
        assert query_text == "q"
        return ["ctx"]

    async def dummy_get_completion(*_args, **_kwargs):
        raise AssertionError("Completion should not be called when only_context=True")

    async def dummy_get_search_type_tools(**_kwargs):
        return [dummy_get_completion, dummy_get_context]

    monkeypatch.setattr(
        search_mod,
        "set_database_global_context_variables",
        dummy_set_database_global_context_variables,
    )
    monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine)
    monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)

    out = await search_mod.search_in_datasets_context(
        search_datasets=[ds],
        query_type=SearchType.CHUNKS,
        query_text="q",
        only_context=True,
    )

    assert out == [(None, ["ctx"], [ds])]


@pytest.mark.asyncio
async def test_search_in_datasets_context_unknown_tool_path(monkeypatch, search_mod):
    ds = _make_dataset(name="ds1")

    async def dummy_set_database_global_context_variables(*_args, **_kwargs):
        return None

    class DummyGraphEngine:
        async def is_empty(self):
            return False

    async def dummy_get_graph_engine():
        return DummyGraphEngine()

    async def dummy_unknown_tool(query_text: str):
        assert query_text == "q"
        return ["u"]

    async def dummy_get_search_type_tools(**_kwargs):
        return [dummy_unknown_tool]

    monkeypatch.setattr(
        search_mod,
        "set_database_global_context_variables",
        dummy_set_database_global_context_variables,
    )
    monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine)
    monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)

    out = await search_mod.search_in_datasets_context(
        search_datasets=[ds],
        query_type=SearchType.CODING_RULES,
        query_text="q",
    )

    assert out == [(["u"], "", [ds])]
```
### New test file (296 additions)

```python
## The Objective of these tests is to cover the search - prepare search results behavior (later to be removed)

import types
from uuid import uuid4

import pytest
from pydantic import BaseModel

from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
from cognee.modules.search.types import SearchType


class DummyDataset(BaseModel):
    id: object
    name: str
    tenant_id: str | None = None
    owner_id: object


def _ds(name="ds1", tenant_id="t1"):
    return DummyDataset(id=uuid4(), name=name, tenant_id=tenant_id, owner_id=uuid4())


def _edge(rel="rel", n1="A", n2="B"):
    node1 = Node(str(uuid4()), attributes={"type": "Entity", "name": n1})
    node2 = Node(str(uuid4()), attributes={"type": "Entity", "name": n2})
    return Edge(node1, node2, attributes={"relationship_name": rel})


@pytest.fixture
def search_mod():
    import importlib

    return importlib.import_module("cognee.modules.search.methods.search")


@pytest.fixture(autouse=True)
def _patch_search_side_effects(monkeypatch, search_mod):
    """
    These tests validate prepare_search_result behavior *through* search.py.
    We only patch unavoidable side effects (telemetry + query/result logging).
    """

    async def dummy_log_query(_query_text, _query_type, _user_id):
        return types.SimpleNamespace(id="qid-1")

    async def dummy_log_result(*_args, **_kwargs):
        return None

    monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None)
    monkeypatch.setattr(search_mod, "log_query", dummy_log_query)
    monkeypatch.setattr(search_mod, "log_result", dummy_log_result)

    yield


@pytest.fixture(autouse=True)
def _patch_resolve_edges_to_text(monkeypatch):
    """
    Keep graph-text conversion deterministic and lightweight.
    """
    import importlib

    psr_mod = importlib.import_module("cognee.modules.search.utils.prepare_search_result")

    async def dummy_resolve_edges_to_text(_edges):
        return "EDGE_TEXT"

    monkeypatch.setattr(psr_mod, "resolve_edges_to_text", dummy_resolve_edges_to_text)

    yield


@pytest.mark.asyncio
async def test_search_access_control_edges_context_produces_graphs_and_context_map(
    monkeypatch, search_mod
):
    user = types.SimpleNamespace(id="u1", tenant_id=None)
    ds = _ds("ds1", "t1")
    context = [_edge("likes")]

    async def dummy_authorized_search(**_kwargs):
        return [(["answer"], context, [ds])]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds.id],
        user=user,
    )

    assert out[0]["dataset_name"] == "ds1"
    assert out[0]["dataset_tenant_id"] == "t1"
    assert out[0]["graphs"] is not None
    assert "ds1" in out[0]["graphs"]
    assert out[0]["graphs"]["ds1"]["nodes"]
    assert out[0]["graphs"]["ds1"]["edges"]
    assert out[0]["search_result"] == ["answer"]


@pytest.mark.asyncio
async def test_search_access_control_insights_context_produces_graphs_and_null_result(
    monkeypatch, search_mod
):
    user = types.SimpleNamespace(id="u1", tenant_id=None)
    ds = _ds("ds1", "t1")
    insights = [
        (
            {"id": "n1", "type": "Entity", "name": "Alice"},
            {"relationship_name": "knows"},
            {"id": "n2", "type": "Entity", "name": "Bob"},
        )
    ]

    async def dummy_authorized_search(**_kwargs):
        return [(["something"], insights, [ds])]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds.id],
        user=user,
    )

    assert out[0]["graphs"] is not None
    assert "ds1" in out[0]["graphs"]
    assert out[0]["search_result"] is None


@pytest.mark.asyncio
async def test_search_access_control_only_context_returns_context_text_map(monkeypatch, search_mod):
    user = types.SimpleNamespace(id="u1", tenant_id=None)
    ds = _ds("ds1", "t1")

    async def dummy_authorized_search(**_kwargs):
        return [(None, ["a", "b"], [ds])]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds.id],
        user=user,
        only_context=True,
    )

    assert out[0]["search_result"] == [{"ds1": "a\nb"}]


@pytest.mark.asyncio
async def test_search_access_control_results_edges_become_graph_result(monkeypatch, search_mod):
    user = types.SimpleNamespace(id="u1", tenant_id=None)
    ds = _ds("ds1", "t1")
    results = [_edge("connected_to")]

    async def dummy_authorized_search(**_kwargs):
        return [(results, "ctx", [ds])]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds.id],
        user=user,
    )

    assert isinstance(out[0]["search_result"][0], dict)
    assert "nodes" in out[0]["search_result"][0]
    assert "edges" in out[0]["search_result"][0]


@pytest.mark.asyncio
async def test_search_use_combined_context_defaults_empty_datasets(monkeypatch, search_mod):
    user = types.SimpleNamespace(id="u1", tenant_id=None)

    async def dummy_authorized_search(**_kwargs):
        return ("answer", "ctx", [])

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=None,
        user=user,
        use_combined_context=True,
    )

    assert out.result == "answer"
    assert out.context == {"all available datasets": "ctx"}
    assert out.datasets[0].name == "all available datasets"


@pytest.mark.asyncio
async def test_search_access_control_context_str_branch(monkeypatch, search_mod):
    """Covers prepare_search_result(context is str) through search()."""
    user = types.SimpleNamespace(id="u1", tenant_id=None)
    ds = _ds("ds1", "t1")

    async def dummy_authorized_search(**_kwargs):
        return [(["answer"], "plain context", [ds])]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds.id],
        user=user,
    )

    assert out[0]["graphs"] is None
    assert out[0]["search_result"] == ["answer"]


@pytest.mark.asyncio
async def test_search_access_control_context_empty_list_branch(monkeypatch, search_mod):
    """Covers prepare_search_result(context is empty list) through search()."""
    user = types.SimpleNamespace(id="u1", tenant_id=None)
    ds = _ds("ds1", "t1")

    async def dummy_authorized_search(**_kwargs):
        return [(["answer"], [], [ds])]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds.id],
        user=user,
    )

    assert out[0]["graphs"] is None
    assert out[0]["search_result"] == ["answer"]


@pytest.mark.asyncio
async def test_search_access_control_multiple_results_list_branch(monkeypatch, search_mod):
    """Covers prepare_search_result(result list length > 1) through search()."""
    user = types.SimpleNamespace(id="u1", tenant_id=None)
    ds = _ds("ds1", "t1")

    async def dummy_authorized_search(**_kwargs):
        return [(["r1", "r2"], "ctx", [ds])]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    out = await search_mod.search(
        query_text="q",
        query_type=SearchType.CHUNKS,
        dataset_ids=[ds.id],
        user=user,
    )

    assert out[0]["search_result"] == [["r1", "r2"]]


@pytest.mark.asyncio
async def test_search_access_control_defaults_empty_datasets(monkeypatch, search_mod):
    """
    Covers prepare_search_result(datasets empty list) through search().

    Note: in access-control mode, search.py expects datasets[0] to have `tenant_id`,
    but prepare_search_result defaults to SearchResultDataset which doesn't define it.
    We assert the current behavior (it raises) so refactors don't silently change it.
    """
    user = types.SimpleNamespace(id="u1", tenant_id=None)

    async def dummy_authorized_search(**_kwargs):
        return [(["answer"], "ctx", [])]

    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)

    with pytest.raises(AttributeError, match="tenant_id"):
        await search_mod.search(
            query_text="q",
            query_type=SearchType.CHUNKS,
            dataset_ids=None,
            user=user,
        )
```