diff --git a/cognee/tests/unit/modules/search/test_get_search_type_tools.py b/cognee/tests/unit/modules/search/test_get_search_type_tools.py
new file mode 100644
index 000000000..3748a4e4b
--- /dev/null
+++ b/cognee/tests/unit/modules/search/test_get_search_type_tools.py
@@ -0,0 +1,220 @@
+import pytest
+
+from cognee.modules.search.exceptions import UnsupportedSearchTypeError
+from cognee.modules.search.types import SearchType
+
+
+class _DummyCommunityRetriever:
+    def __init__(self, *args, **kwargs):
+        self.kwargs = kwargs
+
+    def get_completion(self, *args, **kwargs):
+        return {"kind": "completion", "init": self.kwargs, "args": args, "kwargs": kwargs}
+
+    def get_context(self, *args, **kwargs):
+        return {"kind": "context", "init": self.kwargs, "args": args, "kwargs": kwargs}
+
+
+@pytest.mark.asyncio
+async def test_feeling_lucky_delegates_to_select_search_type(monkeypatch):
+    import cognee.modules.search.methods.get_search_type_tools as mod
+    from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
+
+    async def _fake_select_search_type(query_text: str):
+        assert query_text == "hello"
+        return SearchType.CHUNKS
+
+    monkeypatch.setattr(mod, "select_search_type", _fake_select_search_type)
+
+    tools = await mod.get_search_type_tools(SearchType.FEELING_LUCKY, query_text="hello")
+
+    assert len(tools) == 2
+    assert all(callable(t) for t in tools)
+    assert tools[0].__name__ == "get_completion"
+    assert tools[1].__name__ == "get_context"
+    assert tools[0].__self__.__class__ is ChunksRetriever
+    assert tools[1].__self__.__class__ is ChunksRetriever
+
+
+@pytest.mark.asyncio
+async def test_disallowed_cypher_search_types_raise(monkeypatch):
+    import cognee.modules.search.methods.get_search_type_tools as mod
+
+    monkeypatch.setenv("ALLOW_CYPHER_QUERY", "false")
+
+    with pytest.raises(UnsupportedSearchTypeError, match="disabled"):
+        await mod.get_search_type_tools(SearchType.CYPHER, query_text="MATCH (n) RETURN n")
+
+    with pytest.raises(UnsupportedSearchTypeError, match="disabled"):
+        await mod.get_search_type_tools(SearchType.NATURAL_LANGUAGE, query_text="Find nodes")
+
+
+@pytest.mark.asyncio
+async def test_allowed_cypher_search_types_return_tools(monkeypatch):
+    import cognee.modules.search.methods.get_search_type_tools as mod
+    from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
+
+    monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true")
+
+    tools = await mod.get_search_type_tools(SearchType.CYPHER, query_text="q")
+    assert len(tools) == 2
+    assert tools[0].__name__ == "get_completion"
+    assert tools[1].__name__ == "get_context"
+    assert tools[0].__self__.__class__ is CypherSearchRetriever
+    assert tools[1].__self__.__class__ is CypherSearchRetriever
+
+
+@pytest.mark.asyncio
+async def test_registered_community_retriever_is_used(monkeypatch):
+    """
+    Integration point: community retrievers are loaded from the registry module and should
+    override the default mapping when present.
+    """
+    import cognee.modules.search.methods.get_search_type_tools as mod
+    from cognee.modules.retrieval import registered_community_retrievers as registry
+
+    monkeypatch.setattr(
+        registry,
+        "registered_community_retrievers",
+        {SearchType.SUMMARIES: _DummyCommunityRetriever},
+    )
+
+    tools = await mod.get_search_type_tools(SearchType.SUMMARIES, query_text="q", top_k=7)
+
+    assert len(tools) == 2
+    assert tools[0].__self__.__class__ is _DummyCommunityRetriever
+    assert tools[0].__self__.kwargs["top_k"] == 7
+    assert tools[1].__self__.__class__ is _DummyCommunityRetriever
+    assert tools[1].__self__.kwargs["top_k"] == 7
+
+
+@pytest.mark.asyncio
+async def test_unknown_query_type_raises_unsupported():
+    import cognee.modules.search.methods.get_search_type_tools as mod
+
+    with pytest.raises(UnsupportedSearchTypeError, match="UNKNOWN_TYPE"):
+        await mod.get_search_type_tools("UNKNOWN_TYPE", query_text="q")
+
+
+@pytest.mark.asyncio
+async def test_default_mapping_passes_top_k_to_retrievers():
+    import cognee.modules.search.methods.get_search_type_tools as mod
+    from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
+
+    tools = await mod.get_search_type_tools(SearchType.SUMMARIES, query_text="q", top_k=4)
+    assert len(tools) == 2
+    assert tools[0].__self__.__class__ is SummariesRetriever
+    assert tools[1].__self__.__class__ is SummariesRetriever
+    assert tools[0].__self__.top_k == 4
+    assert tools[1].__self__.top_k == 4
+
+
+@pytest.mark.asyncio
+async def test_chunks_lexical_returns_jaccard_tools():
+    import cognee.modules.search.methods.get_search_type_tools as mod
+    from cognee.modules.retrieval.jaccard_retrival import JaccardChunksRetriever
+
+    tools = await mod.get_search_type_tools(SearchType.CHUNKS_LEXICAL, query_text="q", top_k=3)
+    assert len(tools) == 2
+    assert tools[0].__self__.__class__ is JaccardChunksRetriever
+    assert tools[1].__self__.__class__ is JaccardChunksRetriever
+    assert tools[0].__self__ is tools[1].__self__
+
+
+@pytest.mark.asyncio
+async def test_coding_rules_uses_node_name_as_rules_nodeset_name():
+    import cognee.modules.search.methods.get_search_type_tools as mod
+    from cognee.modules.retrieval.coding_rules_retriever import CodingRulesRetriever
+
+    tools = await mod.get_search_type_tools(SearchType.CODING_RULES, query_text="q", node_name=[])
+    assert len(tools) == 1
+    assert tools[0].__name__ == "get_existing_rules"
+    assert tools[0].__self__.__class__ is CodingRulesRetriever
+
+    assert tools[0].__self__.rules_nodeset_name == ["coding_agent_rules"]
+
+
+@pytest.mark.asyncio
+async def test_feedback_uses_last_k():
+    import cognee.modules.search.methods.get_search_type_tools as mod
+    from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
+
+    tools = await mod.get_search_type_tools(SearchType.FEEDBACK, query_text="q", last_k=11)
+    assert len(tools) == 1
+    assert tools[0].__name__ == "add_feedback"
+    assert tools[0].__self__.__class__ is UserQAFeedback
+    assert tools[0].__self__.last_k == 11
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "query_type, expected_class_name, expected_method_names",
+    [
+        (SearchType.CHUNKS, "ChunksRetriever", ("get_completion", "get_context")),
+        (SearchType.RAG_COMPLETION, "CompletionRetriever", ("get_completion", "get_context")),
+        (SearchType.TRIPLET_COMPLETION, "TripletRetriever", ("get_completion", "get_context")),
+        (
+            SearchType.GRAPH_COMPLETION,
+            "GraphCompletionRetriever",
+            ("get_completion", "get_context"),
+        ),
+        (
+            SearchType.GRAPH_COMPLETION_COT,
"GraphCompletionCotRetriever", + ("get_completion", "get_context"), + ), + ( + SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION, + "GraphCompletionContextExtensionRetriever", + ("get_completion", "get_context"), + ), + ( + SearchType.GRAPH_SUMMARY_COMPLETION, + "GraphSummaryCompletionRetriever", + ("get_completion", "get_context"), + ), + (SearchType.TEMPORAL, "TemporalRetriever", ("get_completion", "get_context")), + ( + SearchType.NATURAL_LANGUAGE, + "NaturalLanguageRetriever", + ("get_completion", "get_context"), + ), + ], +) +async def test_tool_construction_for_supported_search_types( + monkeypatch, query_type, expected_class_name, expected_method_names +): + import cognee.modules.search.methods.get_search_type_tools as mod + + monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true") + + tools = await mod.get_search_type_tools(query_type, query_text="q") + + assert len(tools) == 2 + assert tools[0].__name__ == expected_method_names[0] + assert tools[1].__name__ == expected_method_names[1] + assert tools[0].__self__.__class__.__name__ == expected_class_name + assert tools[1].__self__.__class__.__name__ == expected_class_name + + +@pytest.mark.asyncio +async def test_some_completion_tools_are_callable_without_backends(monkeypatch): + """ + "Making search tools" should include that the returned callables are usable. + For retrievers that accept an explicit `context`, we can call get_completion without touching + DB/LLM backends. + """ + import cognee.modules.search.methods.get_search_type_tools as mod + + monkeypatch.setenv("ALLOW_CYPHER_QUERY", "true") + + for query_type in [ + SearchType.CHUNKS, + SearchType.SUMMARIES, + SearchType.CYPHER, + SearchType.NATURAL_LANGUAGE, + ]: + tools = await mod.get_search_type_tools(query_type, query_text="q") + completion = tools[0] + result = await completion("q", context=["ok"]) + assert result == ["ok"] diff --git a/cognee/tests/unit/modules/search/test_search.py b/cognee/tests/unit/modules/search/test_search.py new file mode 100644 index 000000000..175fd9aa4 --- /dev/null +++ b/cognee/tests/unit/modules/search/test_search.py @@ -0,0 +1,464 @@ +import types +from uuid import uuid4 + +import pytest + +from cognee.modules.search.types import SearchType + + +def _make_user(user_id: str = "u1", tenant_id=None): + return types.SimpleNamespace(id=user_id, tenant_id=tenant_id) + + +def _make_dataset(*, name="ds", tenant_id="t1", dataset_id=None, owner_id=None): + return types.SimpleNamespace( + id=dataset_id or uuid4(), + name=name, + tenant_id=tenant_id, + owner_id=owner_id or uuid4(), + ) + + +@pytest.fixture +def search_mod(): + import importlib + + return importlib.import_module("cognee.modules.search.methods.search") + + +@pytest.fixture(autouse=True) +def _patch_side_effect_boundaries(monkeypatch, search_mod): + """ + Keep production logic; patch only unavoidable side-effect boundaries. 
+    """
+
+    async def dummy_log_query(_query_text, _query_type, _user_id):
+        return types.SimpleNamespace(id="qid-1")
+
+    async def dummy_log_result(*_args, **_kwargs):
+        return None
+
+    async def dummy_prepare_search_result(search_result):
+        if isinstance(search_result, tuple) and len(search_result) == 3:
+            result, context, datasets = search_result
+            return {"result": result, "context": context, "graphs": {}, "datasets": datasets}
+        return {"result": None, "context": None, "graphs": {}, "datasets": []}
+
+    monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None)
+    monkeypatch.setattr(search_mod, "log_query", dummy_log_query)
+    monkeypatch.setattr(search_mod, "log_result", dummy_log_result)
+    monkeypatch.setattr(search_mod, "prepare_search_result", dummy_prepare_search_result)
+
+    yield
+
+
+@pytest.mark.asyncio
+async def test_search_no_access_control_flattens_single_list_result(monkeypatch, search_mod):
+    user = _make_user()
+
+    async def dummy_no_access_control_search(**_kwargs):
+        return (["r"], ["ctx"], [])
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False)
+    monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=None,
+        user=user,
+    )
+
+    assert out == ["r"]
+
+
+@pytest.mark.asyncio
+async def test_search_no_access_control_non_list_result_returns_list(monkeypatch, search_mod):
+    """
+    Covers the non-flattening back-compat branch in `search()`: if the single returned result is
+    not a list, `search()` returns a list of results instead of flattening.
+    """
+    user = _make_user()
+
+    async def dummy_no_access_control_search(**_kwargs):
+        return ("r", ["ctx"], [])
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False)
+    monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=None,
+        user=user,
+    )
+
+    assert out == ["r"]
+
+
+@pytest.mark.asyncio
+async def test_search_no_access_control_only_context_returns_context(monkeypatch, search_mod):
+    user = _make_user()
+
+    async def dummy_no_access_control_search(**_kwargs):
+        return (None, ["ctx"], [])
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: False)
+    monkeypatch.setattr(search_mod, "no_access_control_search", dummy_no_access_control_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=None,
+        user=user,
+        only_context=True,
+    )
+
+    assert out == ["ctx"]
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_returns_dataset_shaped_dicts(monkeypatch, search_mod):
+    user = _make_user()
+    ds = _make_dataset(name="ds1", tenant_id="t1")
+
+    async def dummy_authorized_search(**kwargs):
+        assert kwargs["dataset_ids"] == [ds.id]
+        return [("r", ["ctx"], [ds])]
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=[ds.id],
+        user=user,
+    )
+
+    assert out == [
+        {
+            "search_result": ["r"],
+            "dataset_id": ds.id,
+            "dataset_name": "ds1",
+            "dataset_tenant_id": "t1",
+            "graphs": {},
+        }
+    ]
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_only_context_returns_dataset_shaped_dicts(
+    monkeypatch, search_mod
+):
+    user = _make_user()
+    ds = _make_dataset(name="ds1", tenant_id="t1")
+
+    async def dummy_authorized_search(**_kwargs):
+        return [(None, "ctx", [ds])]
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=[ds.id],
+        user=user,
+        only_context=True,
+    )
+
+    assert out == [
+        {
+            "search_result": ["ctx"],
+            "dataset_id": ds.id,
+            "dataset_name": "ds1",
+            "dataset_tenant_id": "t1",
+            "graphs": {},
+        }
+    ]
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_use_combined_context_returns_combined_model(
+    monkeypatch, search_mod
+):
+    user = _make_user()
+    ds1 = _make_dataset(name="ds1", tenant_id="t1")
+    ds2 = _make_dataset(name="ds2", tenant_id="t1")
+
+    async def dummy_authorized_search(**_kwargs):
+        return ("answer", {"k": "v"}, [ds1, ds2])
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=[ds1.id, ds2.id],
+        user=user,
+        use_combined_context=True,
+    )
+
+    assert out.result == "answer"
+    assert out.context == {"k": "v"}
+    assert out.graphs == {}
+    assert [d.id for d in out.datasets] == [ds1.id, ds2.id]
+
+
+@pytest.mark.asyncio
+async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod):
+    user = _make_user()
+    ds = _make_dataset(name="ds1")
+
+    async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
+        return [ds]
+
+    expected = [("r", ["ctx"], [ds])]
+
+    async def dummy_search_in_datasets_context(**kwargs):
+        assert kwargs.get("use_combined_context", False) is False
+        return expected
+
+    monkeypatch.setattr(
+        search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
+    )
+    monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
+
+    out = await search_mod.authorized_search(
+        query_type=SearchType.CHUNKS,
+        query_text="q",
+        user=user,
+        dataset_ids=[ds.id],
+        use_combined_context=False,
+        only_context=False,
+    )
+
+    assert out == expected
+
+
+@pytest.mark.asyncio
+async def test_authorized_search_use_combined_context_joins_string_context(monkeypatch, search_mod):
+    user = _make_user()
+    ds1 = _make_dataset(name="ds1")
+    ds2 = _make_dataset(name="ds2")
+
+    async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
+        return [ds1, ds2]
+
+    async def dummy_search_in_datasets_context(**kwargs):
+        assert kwargs["only_context"] is True
+        return [(None, ["a"], [ds1]), (None, ["b"], [ds2])]
+
+    seen = {}
+
+    async def dummy_get_completion(query_text, context, session_id=None):
+        seen["query_text"] = query_text
+        seen["context"] = context
+        seen["session_id"] = session_id
+        return ["answer"]
+
+    async def dummy_get_search_type_tools(**_kwargs):
+        return [dummy_get_completion, lambda *_a, **_k: None]
+
+    monkeypatch.setattr(
+        search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
+    )
+    monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
+    monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
+
+    completion, combined_context, datasets = await search_mod.authorized_search(
+        query_type=SearchType.CHUNKS,
query_text="q", + user=user, + dataset_ids=[ds1.id, ds2.id], + use_combined_context=True, + session_id="s1", + ) + + assert combined_context == "a\nb" + assert completion == ["answer"] + assert datasets == [ds1, ds2] + assert seen == {"query_text": "q", "context": "a\nb", "session_id": "s1"} + + +@pytest.mark.asyncio +async def test_authorized_search_use_combined_context_keeps_non_string_context( + monkeypatch, search_mod +): + user = _make_user() + ds1 = _make_dataset(name="ds1") + ds2 = _make_dataset(name="ds2") + + class DummyEdge: + pass + + e1, e2 = DummyEdge(), DummyEdge() + + async def dummy_get_authorized_existing_datasets(*_args, **_kwargs): + return [ds1, ds2] + + async def dummy_search_in_datasets_context(**_kwargs): + return [(None, [e1], [ds1]), (None, [e2], [ds2])] + + async def dummy_get_completion(query_text, context, session_id=None): + assert query_text == "q" + assert context == [e1, e2] + return ["answer"] + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_get_completion] + + monkeypatch.setattr( + search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets + ) + monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + + completion, combined_context, datasets = await search_mod.authorized_search( + query_type=SearchType.CHUNKS, + query_text="q", + user=user, + dataset_ids=[ds1.id, ds2.id], + use_combined_context=True, + ) + + assert combined_context == [e1, e2] + assert completion == ["answer"] + assert datasets == [ds1, ds2] + + +@pytest.mark.asyncio +async def test_search_in_datasets_context_two_tool_context_override_and_is_empty_branches( + monkeypatch, search_mod +): + ds1 = _make_dataset(name="ds1") + ds2 = _make_dataset(name="ds2") + + async def dummy_set_database_global_context_variables(*_args, **_kwargs): + return None + + class DummyGraphEngine: + async def is_empty(self): + return True + + async def dummy_get_graph_engine(): + return DummyGraphEngine() + + async def dummy_get_dataset_data(dataset_id): + return [1] if dataset_id == ds1.id else [] + + calls = {"completion": 0, "context": 0} + + async def dummy_get_context(_query_text: str): + calls["context"] += 1 + return ["ctx"] + + async def dummy_get_completion(_query_text: str, _context, session_id=None): + calls["completion"] += 1 + assert session_id == "s1" + return ["r"] + + async def dummy_get_search_type_tools(**_kwargs): + return [dummy_get_completion, dummy_get_context] + + monkeypatch.setattr( + search_mod, + "set_database_global_context_variables", + dummy_set_database_global_context_variables, + ) + monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine) + monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools) + monkeypatch.setattr("cognee.modules.data.methods.get_dataset_data", dummy_get_dataset_data) + + out = await search_mod.search_in_datasets_context( + search_datasets=[ds1, ds2], + query_type=SearchType.CHUNKS, + query_text="q", + context=["pre_ctx"], + session_id="s1", + ) + + assert out == [(["r"], ["pre_ctx"], [ds1]), (["r"], ["pre_ctx"], [ds2])] + assert calls == {"completion": 2, "context": 0} + + +@pytest.mark.asyncio +async def test_search_in_datasets_context_two_tool_only_context_true(monkeypatch, search_mod): + ds = _make_dataset(name="ds1") + + async def dummy_set_database_global_context_variables(*_args, **_kwargs): + return None + + class DummyGraphEngine: 
+        async def is_empty(self):
+            return False
+
+    async def dummy_get_graph_engine():
+        return DummyGraphEngine()
+
+    async def dummy_get_context(query_text: str):
+        assert query_text == "q"
+        return ["ctx"]
+
+    async def dummy_get_completion(*_args, **_kwargs):
+        raise AssertionError("Completion should not be called when only_context=True")
+
+    async def dummy_get_search_type_tools(**_kwargs):
+        return [dummy_get_completion, dummy_get_context]
+
+    monkeypatch.setattr(
+        search_mod,
+        "set_database_global_context_variables",
+        dummy_set_database_global_context_variables,
+    )
+    monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine)
+    monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
+
+    out = await search_mod.search_in_datasets_context(
+        search_datasets=[ds],
+        query_type=SearchType.CHUNKS,
+        query_text="q",
+        only_context=True,
+    )
+
+    assert out == [(None, ["ctx"], [ds])]
+
+
+@pytest.mark.asyncio
+async def test_search_in_datasets_context_unknown_tool_path(monkeypatch, search_mod):
+    ds = _make_dataset(name="ds1")
+
+    async def dummy_set_database_global_context_variables(*_args, **_kwargs):
+        return None
+
+    class DummyGraphEngine:
+        async def is_empty(self):
+            return False
+
+    async def dummy_get_graph_engine():
+        return DummyGraphEngine()
+
+    async def dummy_unknown_tool(query_text: str):
+        assert query_text == "q"
+        return ["u"]
+
+    async def dummy_get_search_type_tools(**_kwargs):
+        return [dummy_unknown_tool]
+
+    monkeypatch.setattr(
+        search_mod,
+        "set_database_global_context_variables",
+        dummy_set_database_global_context_variables,
+    )
+    monkeypatch.setattr(search_mod, "get_graph_engine", dummy_get_graph_engine)
+    monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
+
+    out = await search_mod.search_in_datasets_context(
+        search_datasets=[ds],
+        query_type=SearchType.CODING_RULES,
+        query_text="q",
+    )
+
+    assert out == [(["u"], "", [ds])]
diff --git a/cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py b/cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py
new file mode 100644
index 000000000..8700e6a1b
--- /dev/null
+++ b/cognee/tests/unit/modules/search/test_search_prepare_search_result_contract.py
@@ -0,0 +1,296 @@
+## These tests cover the search -> prepare_search_result behavior (to be removed later).
+
+import types
+from uuid import uuid4
+
+import pytest
+from pydantic import BaseModel
+
+from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
+from cognee.modules.search.types import SearchType
+
+
+class DummyDataset(BaseModel):
+    id: object
+    name: str
+    tenant_id: str | None = None
+    owner_id: object
+
+
+def _ds(name="ds1", tenant_id="t1"):
+    return DummyDataset(id=uuid4(), name=name, tenant_id=tenant_id, owner_id=uuid4())
+
+
+def _edge(rel="rel", n1="A", n2="B"):
+    node1 = Node(str(uuid4()), attributes={"type": "Entity", "name": n1})
+    node2 = Node(str(uuid4()), attributes={"type": "Entity", "name": n2})
+    return Edge(node1, node2, attributes={"relationship_name": rel})
+
+
+@pytest.fixture
+def search_mod():
+    import importlib
+
+    return importlib.import_module("cognee.modules.search.methods.search")
+
+
+@pytest.fixture(autouse=True)
+def _patch_search_side_effects(monkeypatch, search_mod):
+    """
+    These tests validate prepare_search_result behavior *through* search.py.
+    We only patch unavoidable side effects (telemetry + query/result logging).
+    """
+
+    async def dummy_log_query(_query_text, _query_type, _user_id):
+        return types.SimpleNamespace(id="qid-1")
+
+    async def dummy_log_result(*_args, **_kwargs):
+        return None
+
+    monkeypatch.setattr(search_mod, "send_telemetry", lambda *a, **k: None)
+    monkeypatch.setattr(search_mod, "log_query", dummy_log_query)
+    monkeypatch.setattr(search_mod, "log_result", dummy_log_result)
+
+    yield
+
+
+@pytest.fixture(autouse=True)
+def _patch_resolve_edges_to_text(monkeypatch):
+    """
+    Keep graph-text conversion deterministic and lightweight.
+    """
+    import importlib
+
+    psr_mod = importlib.import_module("cognee.modules.search.utils.prepare_search_result")
+
+    async def dummy_resolve_edges_to_text(_edges):
+        return "EDGE_TEXT"
+
+    monkeypatch.setattr(psr_mod, "resolve_edges_to_text", dummy_resolve_edges_to_text)
+
+    yield
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_edges_context_produces_graphs_and_context_map(
+    monkeypatch, search_mod
+):
+    user = types.SimpleNamespace(id="u1", tenant_id=None)
+    ds = _ds("ds1", "t1")
+    context = [_edge("likes")]
+
+    async def dummy_authorized_search(**_kwargs):
+        return [(["answer"], context, [ds])]
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=[ds.id],
+        user=user,
+    )
+
+    assert out[0]["dataset_name"] == "ds1"
+    assert out[0]["dataset_tenant_id"] == "t1"
+    assert out[0]["graphs"] is not None
+    assert "ds1" in out[0]["graphs"]
+    assert out[0]["graphs"]["ds1"]["nodes"]
+    assert out[0]["graphs"]["ds1"]["edges"]
+    assert out[0]["search_result"] == ["answer"]
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_insights_context_produces_graphs_and_null_result(
+    monkeypatch, search_mod
+):
+    user = types.SimpleNamespace(id="u1", tenant_id=None)
+    ds = _ds("ds1", "t1")
+    insights = [
+        (
+            {"id": "n1", "type": "Entity", "name": "Alice"},
+            {"relationship_name": "knows"},
+            {"id": "n2", "type": "Entity", "name": "Bob"},
+        )
+    ]
+
+    async def dummy_authorized_search(**_kwargs):
+        return [(["something"], insights, [ds])]
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=[ds.id],
+        user=user,
+    )
+
+    assert out[0]["graphs"] is not None
+    assert "ds1" in out[0]["graphs"]
+    assert out[0]["search_result"] is None
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_only_context_returns_context_text_map(monkeypatch, search_mod):
+    user = types.SimpleNamespace(id="u1", tenant_id=None)
+    ds = _ds("ds1", "t1")
+
+    async def dummy_authorized_search(**_kwargs):
+        return [(None, ["a", "b"], [ds])]
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=[ds.id],
+        user=user,
+        only_context=True,
+    )
+
+    assert out[0]["search_result"] == [{"ds1": "a\nb"}]
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_results_edges_become_graph_result(monkeypatch, search_mod):
+    user = types.SimpleNamespace(id="u1", tenant_id=None)
+    ds = _ds("ds1", "t1")
+    results = [_edge("connected_to")]
+
+    async def dummy_authorized_search(**_kwargs):
+        return [(results, "ctx", [ds])]
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=[ds.id],
+        user=user,
+    )
+
+    assert isinstance(out[0]["search_result"][0], dict)
+    assert "nodes" in out[0]["search_result"][0]
+    assert "edges" in out[0]["search_result"][0]
+
+
+@pytest.mark.asyncio
+async def test_search_use_combined_context_defaults_empty_datasets(monkeypatch, search_mod):
+    user = types.SimpleNamespace(id="u1", tenant_id=None)
+
+    async def dummy_authorized_search(**_kwargs):
+        return ("answer", "ctx", [])
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=None,
+        user=user,
+        use_combined_context=True,
+    )
+
+    assert out.result == "answer"
+    assert out.context == {"all available datasets": "ctx"}
+    assert out.datasets[0].name == "all available datasets"
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_context_str_branch(monkeypatch, search_mod):
+    """Covers prepare_search_result(context is str) through search()."""
+    user = types.SimpleNamespace(id="u1", tenant_id=None)
+    ds = _ds("ds1", "t1")
+
+    async def dummy_authorized_search(**_kwargs):
+        return [(["answer"], "plain context", [ds])]
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=[ds.id],
+        user=user,
+    )
+
+    assert out[0]["graphs"] is None
+    assert out[0]["search_result"] == ["answer"]
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_context_empty_list_branch(monkeypatch, search_mod):
+    """Covers prepare_search_result(context is empty list) through search()."""
+    user = types.SimpleNamespace(id="u1", tenant_id=None)
+    ds = _ds("ds1", "t1")
+
+    async def dummy_authorized_search(**_kwargs):
+        return [(["answer"], [], [ds])]
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=[ds.id],
+        user=user,
+    )
+
+    assert out[0]["graphs"] is None
+    assert out[0]["search_result"] == ["answer"]
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_multiple_results_list_branch(monkeypatch, search_mod):
+    """Covers prepare_search_result(result list length > 1) through search()."""
+    user = types.SimpleNamespace(id="u1", tenant_id=None)
+    ds = _ds("ds1", "t1")
+
+    async def dummy_authorized_search(**_kwargs):
+        return [(["r1", "r2"], "ctx", [ds])]
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    out = await search_mod.search(
+        query_text="q",
+        query_type=SearchType.CHUNKS,
+        dataset_ids=[ds.id],
+        user=user,
+    )
+
+    assert out[0]["search_result"] == [["r1", "r2"]]
+
+
+@pytest.mark.asyncio
+async def test_search_access_control_defaults_empty_datasets(monkeypatch, search_mod):
+    """
+    Covers prepare_search_result(datasets empty list) through search().
+
+    Note: in access-control mode, search.py expects datasets[0] to have `tenant_id`,
+    but prepare_search_result defaults to SearchResultDataset which doesn't define it.
+    We assert the current behavior (it raises) so refactors don't silently change it.
+    """
+    user = types.SimpleNamespace(id="u1", tenant_id=None)
+
+    async def dummy_authorized_search(**_kwargs):
+        return [(["answer"], "ctx", [])]
+
+    monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
+    monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
+
+    with pytest.raises(AttributeError, match="tenant_id"):
+        await search_mod.search(
+            query_text="q",
+            query_type=SearchType.CHUNKS,
+            dataset_ids=None,
+            user=user,
+        )
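
Reviewer note (not part of the patch): the tests above pin down a tool contract rather than any one retriever. A minimal sketch of that assumed contract follows -- get_search_type_tools-style tools are bound retriever methods, and get_completion skips backends when an explicit context is passed. SketchRetriever and _demo are hypothetical names for illustration only and do not exist in cognee.

# A self-contained sketch, assuming a retriever-style interface like the one the tests exercise.
import asyncio


class SketchRetriever:
    # Hypothetical stand-in for ChunksRetriever / SummariesRetriever and friends.
    def __init__(self, top_k: int = 5):
        self.top_k = top_k

    async def get_context(self, query_text: str):
        # A real retriever would query a vector or graph store here.
        return [f"context for {query_text!r} (top_k={self.top_k})"]

    async def get_completion(self, query_text: str, context=None, session_id=None):
        # With an explicit context there is no backend call, which is what the
        # "callable without backends" test relies on.
        return context if context is not None else await self.get_context(query_text)


async def _demo():
    retriever = SketchRetriever(top_k=3)
    tools = [retriever.get_completion, retriever.get_context]
    assert tools[0].__self__ is retriever  # the shape the tests assert via __self__
    print(await tools[0]("q", context=["ok"]))  # -> ['ok']


asyncio.run(_demo())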