chore: covering higher level search logic with tests (#1910)

<!-- .github/pull_request_template.md -->

## Description
This PR covers the higher level search.py logic with unit tests. As a
part of the implementation we fully cover the following core logic:

- search.py
- get_search_type_tools (with all the core search types)
- search - prepare_search_results contract (testing behavior from
search.py interface)

## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to
test it locally;
* Proof that it's sufficiently tested.
-->

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Tests**
* Added comprehensive unit test coverage for search functionality,
including search type tool selection, search operations, and result
preparation workflows across multiple scenarios and edge cases.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Vasilije 2025-12-19 14:22:54 +01:00 committed by GitHub
commit 9b2b1a9c13
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 980 additions and 0 deletions

View file

@ -0,0 +1,220 @@
import pytest
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
from cognee.modules.search.types import SearchType
class _DummyCommunityRetriever:
def __init__(self, *args, **kwargs):
self.kwargs = kwargs
def get_completion(self, *args, **kwargs):
return {"kind": "completion", "init": self.kwargs, "args": args, "kwargs": kwargs}
def get_context(self, *args, **kwargs):
return {"kind": "context", "init": self.kwargs, "args": args, "kwargs": kwargs}
@pytest.mark.asyncio
async def test_feeling_lucky_delegates_to_select_search_type(monkeypatch):
import cognee.modules.search.methods.get_search_type_tools as mod
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
async def _fake_select_search_type(query_text: str):
assert query_text == "hello"
return SearchType.CHUNKS
monkeypatch.setattr(mod, "select_search_type", _fake_select_search_type)
tools = await mod.get_search_type_tools(SearchType.FEELING_LUCKY, query_text="hello")
assert len(tools) == 2
assert all(callable(t) for t in tools)
assert tools[0].__name__ == "get_completion"
assert tools[1].__name__ == "get_context"
assert tools[0].__self__.__class__ is ChunksRetriever
assert tools[1].__self__.__class__ is ChunksRetriever
@pytest.mark.asyncio
async def test_disallowed_cypher_search_types_raise(monkeypatch):
import cognee.modules.search.methods.get_search_type_tools as mod
monkeypatch.setenv("ALLOW_CYPHER_QUERY", "false")
with pytest.raises(UnsupportedSearchTypeError, match="disabled"):
await mod.get_search_type_tools(SearchType.CYPHER, query_text="MATCH (n) RETURN n")
with pytest.raises(UnsupportedSearchTypeError, match="disabled"):
await mod.get_search_type_tools(SearchType.NATURAL_LANGUAGE, query_text="Find nodes")
@pytest.mark.asyncio
async def test_allowed_cypher_search_types_return_tools(monkeypatch):
import cognee.modules.search.methods.get_search_type_tools as mod
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true")
tools = await mod.get_search_type_tools(SearchType.CYPHER, query_text="q")
assert len(tools) == 2
assert tools[0].__name__ == "get_completion"
assert tools[1].__name__ == "get_context"
assert tools[0].__self__.__class__ is CypherSearchRetriever
assert tools[1].__self__.__class__ is CypherSearchRetriever
@pytest.mark.asyncio
async def test_registered_community_retriever_is_used(monkeypatch):
"""
Integration point: community retrievers are loaded from the registry module and should
override the default mapping when present.
"""
import cognee.modules.search.methods.get_search_type_tools as mod
from cognee.modules.retrieval import registered_community_retrievers as registry
monkeypatch.setattr(
registry,
"registered_community_retrievers",
{SearchType.SUMMARIES: _DummyCommunityRetriever},
)
tools = await mod.get_search_type_tools(SearchType.SUMMARIES, query_text="q", top_k=7)
assert len(tools) == 2
assert tools[0].__self__.__class__ is _DummyCommunityRetriever
assert tools[0].__self__.kwargs["top_k"] == 7
assert tools[1].__self__.__class__ is _DummyCommunityRetriever
assert tools[1].__self__.kwargs["top_k"] == 7
@pytest.mark.asyncio
async def test_unknown_query_type_raises_unsupported():
import cognee.modules.search.methods.get_search_type_tools as mod
with pytest.raises(UnsupportedSearchTypeError, match="UNKNOWN_TYPE"):
await mod.get_search_type_tools("UNKNOWN_TYPE", query_text="q")
@pytest.mark.asyncio
async def test_default_mapping_passes_top_k_to_retrievers():
import cognee.modules.search.methods.get_search_type_tools as mod
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
tools = await mod.get_search_type_tools(SearchType.SUMMARIES, query_text="q", top_k=4)
assert len(tools) == 2
assert tools[0].__self__.__class__ is SummariesRetriever
assert tools[1].__self__.__class__ is SummariesRetriever
assert tools[0].__self__.top_k == 4
assert tools[1].__self__.top_k == 4
@pytest.mark.asyncio
async def test_chunks_lexical_returns_jaccard_tools():
import cognee.modules.search.methods.get_search_type_tools as mod
from cognee.modules.retrieval.jaccard_retrival import JaccardChunksRetriever
tools = await mod.get_search_type_tools(SearchType.CHUNKS_LEXICAL, query_text="q", top_k=3)
assert len(tools) == 2
assert tools[0].__self__.__class__ is JaccardChunksRetriever
assert tools[1].__self__.__class__ is JaccardChunksRetriever
assert tools[0].__self__ is tools[1].__self__
@pytest.mark.asyncio
async def test_coding_rules_uses_node_name_as_rules_nodeset_name():
import cognee.modules.search.methods.get_search_type_tools as mod
from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
tools = await mod.get_search_type_tools(SearchType.CODING_RULES, query_text="q", node_name=[])
assert len(tools) == 1
assert tools[0].__name__ == "get_existing_rules"
assert tools[0].__self__.__class__ is CodingRulesRetriever
assert tools[0].__self__.rules_nodeset_name == ["coding_agent_rules"]
@pytest.mark.asyncio
async def test_feedback_uses_last_k():
import cognee.modules.search.methods.get_search_type_tools as mod
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
tools = await mod.get_search_type_tools(SearchType.FEEDBACK, query_text="q", last_k=11)
assert len(tools) == 1
assert tools[0].__name__ == "add_feedback"
assert tools[0].__self__.__class__ is UserQAFeedback
assert tools[0].__self__.last_k == 11
@pytest.mark.asyncio
@pytest.mark.parametrize(
"query_type, expected_class_name, expected_method_names",
[
(SearchType.CHUNKS, "ChunksRetriever", ("get_completion", "get_context")),
(SearchType.RAG_COMPLETION, "CompletionRetriever", ("get_completion", "get_context")),
(SearchType.TRIPLET_COMPLETION, "TripletRetriever", ("get_completion", "get_context")),
(
SearchType.GRAPH_COMPLETION,
"GraphCompletionRetriever",
("get_completion", "get_context"),
),
(
SearchType.GRAPH_COMPLETION_COT,
"GraphCompletionCotRetriever",
("get_completion", "get_context"),
),
(
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION,
"GraphCompletionContextExtensionRetriever",
("get_completion", "get_context"),
),
(
SearchType.GRAPH_SUMMARY_COMPLETION,
"GraphSummaryCompletionRetriever",
("get_completion", "get_context"),
),
(SearchType.TEMPORAL, "TemporalRetriever", ("get_completion", "get_context")),
(
SearchType.NATURAL_LANGUAGE,
"NaturalLanguageRetriever",
("get_completion", "get_context"),
),
],
)
async def test_tool_construction_for_supported_search_types(
monkeypatch, query_type, expected_class_name, expected_method_names
):
import cognee.modules.search.methods.get_search_type_tools as mod
monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true")
tools = await mod.get_search_type_tools(query_type, query_text="q")
assert len(tools) == 2
assert tools[0].__name__ == expected_method_names[0]
assert tools[1].__name__ == expected_method_names[1]
assert tools[0].__self__.__class__.__name__ == expected_class_name
assert tools[1].__self__.__class__.__name__ == expected_class_name
@pytest.mark.asyncio
async def test_some_completion_tools_are_callable_without_backends(monkeypatch):
"""
"Making search tools" should include that the returned callables are usable.
For retrievers that accept an explicit `context`, we can call get_completion without touching
DB/LLM backends.
"""
import cognee.modules.search.methods.get_search_type_tools as mod
monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true")
for query_type in [
SearchType.CHUNKS,
SearchType.SUMMARIES,
SearchType.CYPHER,
SearchType.NATURAL_LANGUAGE,
]:
tools = await mod.get_search_type_tools(query_type, query_text="q")
completion = tools[0]
result = await completion("q", context=["ok"])
assert result == ["ok"]

View file

@ -0,0 +1,464 @@
import types
from uuid import uuid4
import pytest
from cognee.modules.search.types import SearchType
def _make_user(user_id: str = "u1", tenant_id=None):
return types.SimpleNamespace(id=user_id, tenant_id=tenant_id)
def _make_dataset(*, name="ds", tenant_id="t1", dataset_id=None, owner_id=None):
return types.SimpleNamespace(
id=dataset_id or uuid4(),
name=name,
tenant_id=tenant_id,
owner_id=owner_id or uuid4(),
)
@pytest.fixture
def search_mod():
import importlib
return importlib.import_module("cognee.modules.search.methods.search")
@pytest.fixture(autouse=True)
def _patch_side_effect_boundaries(monkeypatch, search_mod):
"""
Keep production logic; patch only unavoidable side-effect boundaries.
"""
async def dummy_log_query(_query_text, _query_type, _user_id):
return types.SimpleNamespace(id="qid-1")
async def dummy_log_result(*_args, **_kwargs):
return None
async def dummy_prepare_search_result(search_result):
if isinstance(search_result, tuple) and len(search_result) == 3:
result, context, datasets = search_result
return {"result": result, "context": context, "graphs": {}, "datasets": datasets}
return {"result": None, "context": None, "graphs": {}, "datasets": []}
monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None)
monkeypatch.setattr(search_mod, "log_query", dummy_log_query)
monkeypatch.setattr(search_mod, "log_result", dummy_log_result)
monkeypatch.setattr(search_mod, "prepare_search_result", dummy_prepare_search_result)
yield
@pytest.mark.asyncio
async def test_search_no_access_control_flattens_single_list_result(monkeypatch, search_mod):
user = _make_user()
async def dummy_no_access_control_search(**_kwargs):
return (["r"], ["ctx"], [])
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False)
monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=None,
user=user,
)
assert out == ["r"]
@pytest.mark.asyncio
async def test_search_no_access_control_non_list_result_returns_list(monkeypatch, search_mod):
"""
Covers the non-flattening back-compat branch in `search()`: if the single returned result is
not a list, `search()` returns a list of results instead of flattening.
"""
user = _make_user()
async def dummy_no_access_control_search(**_kwargs):
return ("r", ["ctx"], [])
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False)
monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=None,
user=user,
)
assert out == ["r"]
@pytest.mark.asyncio
async def test_search_no_access_control_only_context_returns_context(monkeypatch, search_mod):
user = _make_user()
async def dummy_no_access_control_search(**_kwargs):
return (None, ["ctx"], [])
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False)
monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=None,
user=user,
only_context=True,
)
assert out == ["ctx"]
@pytest.mark.asyncio
async def test_search_access_control_returns_dataset_shaped_dicts(monkeypatch, search_mod):
user = _make_user()
ds = _make_dataset(name="ds1", tenant_id="t1")
async def dummy_authorized_search(**kwargs):
assert kwargs["dataset_ids"] == [ds.id]
return [("r", ["ctx"], [ds])]
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
)
assert out == [
{
"search_result": ["r"],
"dataset_id": ds.id,
"dataset_name": "ds1",
"dataset_tenant_id": "t1",
"graphs": {},
}
]
@pytest.mark.asyncio
async def test_search_access_control_only_context_returns_dataset_shaped_dicts(
monkeypatch, search_mod
):
user = _make_user()
ds = _make_dataset(name="ds1", tenant_id="t1")
async def dummy_authorized_search(**_kwargs):
return [(None, "ctx", [ds])]
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
only_context=True,
)
assert out == [
{
"search_result": ["ctx"],
"dataset_id": ds.id,
"dataset_name": "ds1",
"dataset_tenant_id": "t1",
"graphs": {},
}
]
@pytest.mark.asyncio
async def test_search_access_control_use_combined_context_returns_combined_model(
monkeypatch, search_mod
):
user = _make_user()
ds1 = _make_dataset(name="ds1", tenant_id="t1")
ds2 = _make_dataset(name="ds2", tenant_id="t1")
async def dummy_authorized_search(**_kwargs):
return ("answer", {"k": "v"}, [ds1, ds2])
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds1.id, ds2.id],
user=user,
use_combined_context=True,
)
assert out.result == "answer"
assert out.context == {"k": "v"}
assert out.graphs == {}
assert [d.id for d in out.datasets] == [ds1.id, ds2.id]
@pytest.mark.asyncio
async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod):
user = _make_user()
ds = _make_dataset(name="ds1")
async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
return [ds]
expected = [("r", ["ctx"], [ds])]
async def dummy_search_in_datasets_context(**kwargs):
assert kwargs["use_combined_context"] is False if "use_combined_context" in kwargs else True
return expected
monkeypatch.setattr(
search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
)
monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
out = await search_mod.authorized_search(
query_type=SearchType.CHUNKS,
query_text="q",
user=user,
dataset_ids=[ds.id],
use_combined_context=False,
only_context=False,
)
assert out == expected
@pytest.mark.asyncio
async def test_authorized_search_use_combined_context_joins_string_context(monkeypatch, search_mod):
user = _make_user()
ds1 = _make_dataset(name="ds1")
ds2 = _make_dataset(name="ds2")
async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
return [ds1, ds2]
async def dummy_search_in_datasets_context(**kwargs):
assert kwargs["only_context"] is True
return [(None, ["a"], [ds1]), (None, ["b"], [ds2])]
seen = {}
async def dummy_get_completion(query_text, context, session_id=None):
seen["query_text"] = query_text
seen["context"] = context
seen["session_id"] = session_id
return ["answer"]
async def dummy_get_search_type_tools(**_kwargs):
return [dummy_get_completion, lambda *_a, **_k: None]
monkeypatch.setattr(
search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
)
monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
completion, combined_context, datasets = await search_mod.authorized_search(
query_type=SearchType.CHUNKS,
query_text="q",
user=user,
dataset_ids=[ds1.id, ds2.id],
use_combined_context=True,
session_id="s1",
)
assert combined_context == "a\nb"
assert completion == ["answer"]
assert datasets == [ds1, ds2]
assert seen == {"query_text": "q", "context": "a\nb", "session_id": "s1"}
@pytest.mark.asyncio
async def test_authorized_search_use_combined_context_keeps_non_string_context(
monkeypatch, search_mod
):
user = _make_user()
ds1 = _make_dataset(name="ds1")
ds2 = _make_dataset(name="ds2")
class DummyEdge:
pass
e1, e2 = DummyEdge(), DummyEdge()
async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
return [ds1, ds2]
async def dummy_search_in_datasets_context(**_kwargs):
return [(None, [e1], [ds1]), (None, [e2], [ds2])]
async def dummy_get_completion(query_text, context, session_id=None):
assert query_text == "q"
assert context == [e1, e2]
return ["answer"]
async def dummy_get_search_type_tools(**_kwargs):
return [dummy_get_completion]
monkeypatch.setattr(
search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
)
monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
completion, combined_context, datasets = await search_mod.authorized_search(
query_type=SearchType.CHUNKS,
query_text="q",
user=user,
dataset_ids=[ds1.id, ds2.id],
use_combined_context=True,
)
assert combined_context == [e1, e2]
assert completion == ["answer"]
assert datasets == [ds1, ds2]
@pytest.mark.asyncio
async def test_search_in_datasets_context_two_tool_context_override_and_is_empty_branches(
monkeypatch, search_mod
):
ds1 = _make_dataset(name="ds1")
ds2 = _make_dataset(name="ds2")
async def dummy_set_database_global_context_variables(*_args, **_kwargs):
return None
class DummyGraphEngine:
async def is_empty(self):
return True
async def dummy_get_graph_engine():
return DummyGraphEngine()
async def dummy_get_dataset_data(dataset_id):
return [1] if dataset_id == ds1.id else []
calls = {"completion": 0, "context": 0}
async def dummy_get_context(_query_text: str):
calls["context"] += 1
return ["ctx"]
async def dummy_get_completion(_query_text: str, _context, session_id=None):
calls["completion"] += 1
assert session_id == "s1"
return ["r"]
async def dummy_get_search_type_tools(**_kwargs):
return [dummy_get_completion, dummy_get_context]
monkeypatch.setattr(
search_mod,
"set_database_global_context_variables",
dummy_set_database_global_context_variables,
)
monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine)
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
monkeypatch.setattr("cognee.modules.data.methods.get_dataset_data", dummy_get_dataset_data)
out = await search_mod.search_in_datasets_context(
search_datasets=[ds1, ds2],
query_type=SearchType.CHUNKS,
query_text="q",
context=["pre_ctx"],
session_id="s1",
)
assert out == [(["r"], ["pre_ctx"], [ds1]), (["r"], ["pre_ctx"], [ds2])]
assert calls == {"completion": 2, "context": 0}
@pytest.mark.asyncio
async def test_search_in_datasets_context_two_tool_only_context_true(monkeypatch, search_mod):
ds = _make_dataset(name="ds1")
async def dummy_set_database_global_context_variables(*_args, **_kwargs):
return None
class DummyGraphEngine:
async def is_empty(self):
return False
async def dummy_get_graph_engine():
return DummyGraphEngine()
async def dummy_get_context(query_text: str):
assert query_text == "q"
return ["ctx"]
async def dummy_get_completion(*_args, **_kwargs):
raise AssertionError("Completion should not be called when only_context=True")
async def dummy_get_search_type_tools(**_kwargs):
return [dummy_get_completion, dummy_get_context]
monkeypatch.setattr(
search_mod,
"set_database_global_context_variables",
dummy_set_database_global_context_variables,
)
monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine)
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
out = await search_mod.search_in_datasets_context(
search_datasets=[ds],
query_type=SearchType.CHUNKS,
query_text="q",
only_context=True,
)
assert out == [(None, ["ctx"], [ds])]
@pytest.mark.asyncio
async def test_search_in_datasets_context_unknown_tool_path(monkeypatch, search_mod):
ds = _make_dataset(name="ds1")
async def dummy_set_database_global_context_variables(*_args, **_kwargs):
return None
class DummyGraphEngine:
async def is_empty(self):
return False
async def dummy_get_graph_engine():
return DummyGraphEngine()
async def dummy_unknown_tool(query_text: str):
assert query_text == "q"
return ["u"]
async def dummy_get_search_type_tools(**_kwargs):
return [dummy_unknown_tool]
monkeypatch.setattr(
search_mod,
"set_database_global_context_variables",
dummy_set_database_global_context_variables,
)
monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine)
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
out = await search_mod.search_in_datasets_context(
search_datasets=[ds],
query_type=SearchType.CODING_RULES,
query_text="q",
)
assert out == [(["u"], "", [ds])]

View file

@ -0,0 +1,296 @@
## The Objective of these tests is to cover the search - prepare search results behavior (later to be removed)
import types
from uuid import uuid4
import pytest
from pydantic import BaseModel
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
from cognee.modules.search.types import SearchType
class DummyDataset(BaseModel):
id: object
name: str
tenant_id: str | None = None
owner_id: object
def _ds(name="ds1", tenant_id="t1"):
return DummyDataset(id=uuid4(), name=name, tenant_id=tenant_id, owner_id=uuid4())
def _edge(rel="rel", n1="A", n2="B"):
node1 = Node(str(uuid4()), attributes={"type": "Entity", "name": n1})
node2 = Node(str(uuid4()), attributes={"type": "Entity", "name": n2})
return Edge(node1, node2, attributes={"relationship_name": rel})
@pytest.fixture
def search_mod():
import importlib
return importlib.import_module("cognee.modules.search.methods.search")
@pytest.fixture(autouse=True)
def _patch_search_side_effects(monkeypatch, search_mod):
"""
These tests validate prepare_search_result behavior *through* search.py.
We only patch unavoidable side effects (telemetry + query/result logging).
"""
async def dummy_log_query(_query_text, _query_type, _user_id):
return types.SimpleNamespace(id="qid-1")
async def dummy_log_result(*_args, **_kwargs):
return None
monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None)
monkeypatch.setattr(search_mod, "log_query", dummy_log_query)
monkeypatch.setattr(search_mod, "log_result", dummy_log_result)
yield
@pytest.fixture(autouse=True)
def _patch_resolve_edges_to_text(monkeypatch):
"""
Keep graph-text conversion deterministic and lightweight.
"""
import importlib
psr_mod = importlib.import_module("cognee.modules.search.utils.prepare_search_result")
async def dummy_resolve_edges_to_text(_edges):
return "EDGE_TEXT"
monkeypatch.setattr(psr_mod, "resolve_edges_to_text", dummy_resolve_edges_to_text)
yield
@pytest.mark.asyncio
async def test_search_access_control_edges_context_produces_graphs_and_context_map(
monkeypatch, search_mod
):
user = types.SimpleNamespace(id="u1", tenant_id=None)
ds = _ds("ds1", "t1")
context = [_edge("likes")]
async def dummy_authorized_search(**_kwargs):
return [(["answer"], context, [ds])]
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
)
assert out[0]["dataset_name"] == "ds1"
assert out[0]["dataset_tenant_id"] == "t1"
assert out[0]["graphs"] is not None
assert "ds1" in out[0]["graphs"]
assert out[0]["graphs"]["ds1"]["nodes"]
assert out[0]["graphs"]["ds1"]["edges"]
assert out[0]["search_result"] == ["answer"]
@pytest.mark.asyncio
async def test_search_access_control_insights_context_produces_graphs_and_null_result(
monkeypatch, search_mod
):
user = types.SimpleNamespace(id="u1", tenant_id=None)
ds = _ds("ds1", "t1")
insights = [
(
{"id": "n1", "type": "Entity", "name": "Alice"},
{"relationship_name": "knows"},
{"id": "n2", "type": "Entity", "name": "Bob"},
)
]
async def dummy_authorized_search(**_kwargs):
return [(["something"], insights, [ds])]
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
)
assert out[0]["graphs"] is not None
assert "ds1" in out[0]["graphs"]
assert out[0]["search_result"] is None
@pytest.mark.asyncio
async def test_search_access_control_only_context_returns_context_text_map(monkeypatch, search_mod):
user = types.SimpleNamespace(id="u1", tenant_id=None)
ds = _ds("ds1", "t1")
async def dummy_authorized_search(**_kwargs):
return [(None, ["a", "b"], [ds])]
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
only_context=True,
)
assert out[0]["search_result"] == [{"ds1": "a\nb"}]
@pytest.mark.asyncio
async def test_search_access_control_results_edges_become_graph_result(monkeypatch, search_mod):
user = types.SimpleNamespace(id="u1", tenant_id=None)
ds = _ds("ds1", "t1")
results = [_edge("connected_to")]
async def dummy_authorized_search(**_kwargs):
return [(results, "ctx", [ds])]
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
)
assert isinstance(out[0]["search_result"][0], dict)
assert "nodes" in out[0]["search_result"][0]
assert "edges" in out[0]["search_result"][0]
@pytest.mark.asyncio
async def test_search_use_combined_context_defaults_empty_datasets(monkeypatch, search_mod):
user = types.SimpleNamespace(id="u1", tenant_id=None)
async def dummy_authorized_search(**_kwargs):
return ("answer", "ctx", [])
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=None,
user=user,
use_combined_context=True,
)
assert out.result == "answer"
assert out.context == {"all available datasets": "ctx"}
assert out.datasets[0].name == "all available datasets"
@pytest.mark.asyncio
async def test_search_access_control_context_str_branch(monkeypatch, search_mod):
"""Covers prepare_search_result(context is str) through search()."""
user = types.SimpleNamespace(id="u1", tenant_id=None)
ds = _ds("ds1", "t1")
async def dummy_authorized_search(**_kwargs):
return [(["answer"], "plain context", [ds])]
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
)
assert out[0]["graphs"] is None
assert out[0]["search_result"] == ["answer"]
@pytest.mark.asyncio
async def test_search_access_control_context_empty_list_branch(monkeypatch, search_mod):
"""Covers prepare_search_result(context is empty list) through search()."""
user = types.SimpleNamespace(id="u1", tenant_id=None)
ds = _ds("ds1", "t1")
async def dummy_authorized_search(**_kwargs):
return [(["answer"], [], [ds])]
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
)
assert out[0]["graphs"] is None
assert out[0]["search_result"] == ["answer"]
@pytest.mark.asyncio
async def test_search_access_control_multiple_results_list_branch(monkeypatch, search_mod):
"""Covers prepare_search_result(result list length > 1) through search()."""
user = types.SimpleNamespace(id="u1", tenant_id=None)
ds = _ds("ds1", "t1")
async def dummy_authorized_search(**_kwargs):
return [(["r1", "r2"], "ctx", [ds])]
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds.id],
user=user,
)
assert out[0]["search_result"] == [["r1", "r2"]]
@pytest.mark.asyncio
async def test_search_access_control_defaults_empty_datasets(monkeypatch, search_mod):
"""
Covers prepare_search_result(datasets empty list) through search().
Note: in access-control mode, search.py expects datasets[0] to have `tenant_id`,
but prepare_search_result defaults to SearchResultDataset which doesn't define it.
We assert the current behavior (it raises) so refactors don't silently change it.
"""
user = types.SimpleNamespace(id="u1", tenant_id=None)
async def dummy_authorized_search(**_kwargs):
return [(["answer"], "ctx", [])]
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
with pytest.raises(AttributeError, match="tenant_id"):
await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=None,
user=user,
)