feat: Remove combined search (#1990)

- Remove use_combined_context parameter from search functions
- Remove CombinedSearchResult class from types module
- Update API routers to remove combined search support
- Remove prepare_combined_context helper function
- Update tutorial notebook to remove use_combined_context usage
- Simplify search return types to always return List[SearchResult]

This removes the combined search feature which aggregated results across
multiple datasets into a single response. Users can still search across
multiple datasets and get results per dataset.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- .github/pull_request_template.md -->

## Description
<!--
Please provide a clear, human-generated description of the changes in
this PR.
DO NOT use AI-generated descriptions. We want to understand your thought
process and reasoning.
-->

## Acceptance Criteria
<!--
* Key requirements to the new feature or modification;
* Proof that the changes work and meet the requirements;
* Include instructions on how to verify the changes. Describe how to
test it locally;
* Proof that it's sufficiently tested.
-->

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Breaking Changes**
  * Search API response simplified: the combined-context result type was removed and the legacy combined-context request flag eliminated, changing response shapes.

* **New Features**
  * dataset_name added to each search result for clearer attribution.

* **Refactor**
  * Search logic and return shapes streamlined for access-control and per-dataset flows; telemetry and request parameters aligned.

* **Tests**
  * Combined-context related tests removed or updated to reflect the simplified behavior.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Igor Ilic 2026-01-14 11:31:31 +01:00 committed by GitHub
commit f09f66e90d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 61 additions and 317 deletions

View file

@ -6,7 +6,7 @@ from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult from cognee.modules.search.types import SearchType, SearchResult
from cognee.api.DTO import InDTO, OutDTO from cognee.api.DTO import InDTO, OutDTO
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError
from cognee.modules.users.models import User from cognee.modules.users.models import User
@ -31,7 +31,7 @@ class SearchPayloadDTO(InDTO):
node_name: Optional[list[str]] = Field(default=None, example=[]) node_name: Optional[list[str]] = Field(default=None, example=[])
top_k: Optional[int] = Field(default=10) top_k: Optional[int] = Field(default=10)
only_context: bool = Field(default=False) only_context: bool = Field(default=False)
use_combined_context: bool = Field(default=False) verbose: bool = Field(default=False)
def get_search_router() -> APIRouter: def get_search_router() -> APIRouter:
@ -74,7 +74,7 @@ def get_search_router() -> APIRouter:
except Exception as error: except Exception as error:
return JSONResponse(status_code=500, content={"error": str(error)}) return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List]) @router.post("", response_model=Union[List[SearchResult], List])
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
""" """
Search for nodes in the graph database. Search for nodes in the graph database.
@ -118,7 +118,7 @@ def get_search_router() -> APIRouter:
"node_name": payload.node_name, "node_name": payload.node_name,
"top_k": payload.top_k, "top_k": payload.top_k,
"only_context": payload.only_context, "only_context": payload.only_context,
"use_combined_context": payload.use_combined_context, "verbose": payload.verbose,
"cognee_version": cognee_version, "cognee_version": cognee_version,
}, },
) )
@ -135,8 +135,8 @@ def get_search_router() -> APIRouter:
system_prompt=payload.system_prompt, system_prompt=payload.system_prompt,
node_name=payload.node_name, node_name=payload.node_name,
top_k=payload.top_k, top_k=payload.top_k,
verbose=payload.verbose,
only_context=payload.only_context, only_context=payload.only_context,
use_combined_context=payload.use_combined_context,
) )
return jsonable_encoder(results) return jsonable_encoder(results)

View file

@ -4,7 +4,7 @@ from typing import Union, Optional, List, Type
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult from cognee.modules.search.types import SearchResult, SearchType
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.modules.search.methods import search as search_function from cognee.modules.search.methods import search as search_function
from cognee.modules.data.methods import get_authorized_existing_datasets from cognee.modules.data.methods import get_authorized_existing_datasets
@ -32,12 +32,11 @@ async def search(
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = 1, last_k: Optional[int] = 1,
only_context: bool = False, only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None, session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100, wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5, triplet_distance_penalty: Optional[float] = 3.5,
verbose: bool = False, verbose: bool = False,
) -> Union[List[SearchResult], CombinedSearchResult]: ) -> List[SearchResult]:
""" """
Search and query the knowledge graph for insights, information, and connections. Search and query the knowledge graph for insights, information, and connections.
@ -217,7 +216,6 @@ async def search(
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context, only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id, session_id=session_id,
wide_search_top_k=wide_search_top_k, wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty, triplet_distance_penalty=triplet_distance_penalty,

View file

@ -27,6 +27,5 @@ await cognee.cognify(datasets=["python-development-with-cognee"], temporal_cogni
results = await cognee.search( results = await cognee.search(
"What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?", "What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?",
datasets=["python-development-with-cognee"], datasets=["python-development-with-cognee"],
use_combined_context=True, # Used to show reasoning graph visualization
) )
print(results) print(results)

View file

@ -14,8 +14,6 @@ from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.search.types import ( from cognee.modules.search.types import (
SearchResult, SearchResult,
CombinedSearchResult,
SearchResultDataset,
SearchType, SearchType,
) )
from cognee.modules.search.operations import log_query, log_result from cognee.modules.search.operations import log_query, log_result
@ -45,12 +43,11 @@ async def search(
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = False, only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None, session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100, wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5, triplet_distance_penalty: Optional[float] = 3.5,
verbose: bool = False, verbose=False,
) -> Union[CombinedSearchResult, List[SearchResult]]: ) -> List[SearchResult]:
""" """
Args: Args:
@ -91,7 +88,6 @@ async def search(
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context, only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id, session_id=session_id,
wide_search_top_k=wide_search_top_k, wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty, triplet_distance_penalty=triplet_distance_penalty,
@ -128,93 +124,63 @@ async def search(
query.id, query.id,
json.dumps( json.dumps(
jsonable_encoder( jsonable_encoder(
await prepare_search_result( [await prepare_search_result(search_result) for search_result in search_results]
search_results[0] if isinstance(search_results, list) else search_results
)
if use_combined_context
else [
await prepare_search_result(search_result) for search_result in search_results
]
) )
), ),
user.id, user.id,
) )
if use_combined_context: # This is for maintaining backwards compatibility
# Note: combined context search must always be verbose and return a CombinedSearchResult with graphs info if backend_access_control_enabled():
prepared_search_results = await prepare_search_result( return_value = []
search_results[0] if isinstance(search_results, list) else search_results for search_result in search_results:
) prepared_search_results = await prepare_search_result(search_result)
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
return CombinedSearchResult( result = prepared_search_results["result"]
result=result, graphs = prepared_search_results["graphs"]
graphs=graphs, context = prepared_search_results["context"]
context=context, datasets = prepared_search_results["datasets"]
datasets=[
SearchResultDataset( if only_context:
id=dataset.id, search_result_dict = {
name=dataset.name, "search_result": [context] if context else None,
) "dataset_id": datasets[0].id,
for dataset in datasets "dataset_name": datasets[0].name,
], "dataset_tenant_id": datasets[0].tenant_id,
) }
if verbose:
# Include graphs only in verbose mode
search_result_dict["graphs"] = graphs
return_value.append(search_result_dict)
else:
search_result_dict = {
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
}
if verbose:
# Include graphs only in verbose mode
search_result_dict["graphs"] = graphs
return_value.append(search_result_dict)
return return_value
else: else:
# This is for maintaining backwards compatibility return_value = []
if backend_access_control_enabled(): if only_context:
return_value = []
for search_result in search_results: for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result) prepared_search_results = await prepare_search_result(search_result)
return_value.append(prepared_search_results["context"])
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
if only_context:
search_result_dict = {
"search_result": [context] if context else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
}
if verbose:
# Include graphs only in verbose mode
search_result_dict["graphs"] = graphs
return_value.append(search_result_dict)
else:
search_result_dict = {
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
}
if verbose:
# Include graphs only in verbose mode
search_result_dict["graphs"] = graphs
return_value.append(search_result_dict)
return return_value
else: else:
return_value = [] for search_result in search_results:
if only_context: result, context, datasets = search_result
for search_result in search_results: return_value.append(result)
prepared_search_results = await prepare_search_result(search_result) # For maintaining backwards compatibility
return_value.append(prepared_search_results["context"]) if len(return_value) == 1 and isinstance(return_value[0], list):
else: return return_value[0]
for search_result in search_results: else:
result, context, datasets = search_result return return_value
return_value.append(result)
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
async def authorized_search( async def authorized_search(
@ -230,14 +196,10 @@ async def authorized_search(
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = False, only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None, session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100, wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5, triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[ ) -> List[Tuple[Any, Union[List[Edge], str], List[Dataset]]]:
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
]:
""" """
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
Not to be used outside of active access control mode. Not to be used outside of active access control mode.
@ -247,70 +209,6 @@ async def authorized_search(
datasets=dataset_ids, permission_type="read", user=user datasets=dataset_ids, permission_type="read", user=user
) )
if use_combined_context:
search_responses = await search_in_datasets_context(
search_datasets=search_datasets,
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=True,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
context = {}
datasets: List[Dataset] = []
for _, search_context, search_datasets in search_responses:
for dataset in search_datasets:
context[str(dataset.id)] = search_context
datasets.extend(search_datasets)
specific_search_tools = await get_search_type_tools(
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
[get_completion, _] = search_tools
else:
get_completion = search_tools[0]
def prepare_combined_context(
context,
) -> Union[List[Edge], str]:
combined_context = []
for dataset_context in context.values():
combined_context += dataset_context
if combined_context and isinstance(combined_context[0], str):
return "\n".join(combined_context)
return combined_context
combined_context = prepare_combined_context(context)
completion = await get_completion(query_text, combined_context, session_id=session_id)
return completion, combined_context, datasets
# Searches all provided datasets and handles setting up of appropriate database context based on permissions # Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await search_in_datasets_context( search_results = await search_in_datasets_context(
search_datasets=search_datasets, search_datasets=search_datasets,
@ -326,6 +224,7 @@ async def authorized_search(
only_context=only_context, only_context=only_context,
session_id=session_id, session_id=session_id,
wide_search_top_k=wide_search_top_k, wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
) )
return search_results return search_results

View file

@ -1,6 +1,6 @@
from uuid import UUID from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
from typing import Any, Dict, List, Optional from typing import Any, Optional
class SearchResultDataset(BaseModel): class SearchResultDataset(BaseModel):
@ -8,13 +8,6 @@ class SearchResultDataset(BaseModel):
name: str name: str
class CombinedSearchResult(BaseModel):
result: Optional[Any]
context: Dict[str, Any]
graphs: Optional[Dict[str, Any]] = {}
datasets: Optional[List[SearchResultDataset]] = None
class SearchResult(BaseModel): class SearchResult(BaseModel):
search_result: Any search_result: Any
dataset_id: Optional[UUID] dataset_id: Optional[UUID]

View file

@ -1,2 +1,2 @@
from .SearchType import SearchType from .SearchType import SearchType
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult from .SearchResult import SearchResult, SearchResultDataset

View file

@ -199,35 +199,7 @@ async def test_search_access_control_only_context_returns_dataset_shaped_dicts(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_access_control_use_combined_context_returns_combined_model( async def test_authorized_search_delegates_to_search_in_datasets_context(monkeypatch, search_mod):
monkeypatch, search_mod
):
user = _make_user()
ds1 = _make_dataset(name="ds1", tenant_id="t1")
ds2 = _make_dataset(name="ds2", tenant_id="t1")
async def dummy_authorized_search(**_kwargs):
return ("answer", {"k": "v"}, [ds1, ds2])
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=[ds1.id, ds2.id],
user=user,
use_combined_context=True,
)
assert out.result == "answer"
assert out.context == {"k": "v"}
assert out.graphs == {}
assert [d.id for d in out.datasets] == [ds1.id, ds2.id]
@pytest.mark.asyncio
async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod):
user = _make_user() user = _make_user()
ds = _make_dataset(name="ds1") ds = _make_dataset(name="ds1")
@ -237,7 +209,6 @@ async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod)
expected = [("r", ["ctx"], [ds])] expected = [("r", ["ctx"], [ds])]
async def dummy_search_in_datasets_context(**kwargs): async def dummy_search_in_datasets_context(**kwargs):
assert kwargs["use_combined_context"] is False if "use_combined_context" in kwargs else True
return expected return expected
monkeypatch.setattr( monkeypatch.setattr(
@ -250,104 +221,12 @@ async def test_authorized_search_non_combined_delegates(monkeypatch, search_mod)
query_text="q", query_text="q",
user=user, user=user,
dataset_ids=[ds.id], dataset_ids=[ds.id],
use_combined_context=False,
only_context=False, only_context=False,
) )
assert out == expected assert out == expected
@pytest.mark.asyncio
async def test_authorized_search_use_combined_context_joins_string_context(monkeypatch, search_mod):
user = _make_user()
ds1 = _make_dataset(name="ds1")
ds2 = _make_dataset(name="ds2")
async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
return [ds1, ds2]
async def dummy_search_in_datasets_context(**kwargs):
assert kwargs["only_context"] is True
return [(None, ["a"], [ds1]), (None, ["b"], [ds2])]
seen = {}
async def dummy_get_completion(query_text, context, session_id=None):
seen["query_text"] = query_text
seen["context"] = context
seen["session_id"] = session_id
return ["answer"]
async def dummy_get_search_type_tools(**_kwargs):
return [dummy_get_completion, lambda *_a, **_k: None]
monkeypatch.setattr(
search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
)
monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
completion, combined_context, datasets = await search_mod.authorized_search(
query_type=SearchType.CHUNKS,
query_text="q",
user=user,
dataset_ids=[ds1.id, ds2.id],
use_combined_context=True,
session_id="s1",
)
assert combined_context == "a\nb"
assert completion == ["answer"]
assert datasets == [ds1, ds2]
assert seen == {"query_text": "q", "context": "a\nb", "session_id": "s1"}
@pytest.mark.asyncio
async def test_authorized_search_use_combined_context_keeps_non_string_context(
monkeypatch, search_mod
):
user = _make_user()
ds1 = _make_dataset(name="ds1")
ds2 = _make_dataset(name="ds2")
class DummyEdge:
pass
e1, e2 = DummyEdge(), DummyEdge()
async def dummy_get_authorized_existing_datasets(*_args, **_kwargs):
return [ds1, ds2]
async def dummy_search_in_datasets_context(**_kwargs):
return [(None, [e1], [ds1]), (None, [e2], [ds2])]
async def dummy_get_completion(query_text, context, session_id=None):
assert query_text == "q"
assert context == [e1, e2]
return ["answer"]
async def dummy_get_search_type_tools(**_kwargs):
return [dummy_get_completion]
monkeypatch.setattr(
search_mod, "get_authorized_existing_datasets", dummy_get_authorized_existing_datasets
)
monkeypatch.setattr(search_mod, "search_in_datasets_context", dummy_search_in_datasets_context)
monkeypatch.setattr(search_mod, "get_search_type_tools", dummy_get_search_type_tools)
completion, combined_context, datasets = await search_mod.authorized_search(
query_type=SearchType.CHUNKS,
query_text="q",
user=user,
dataset_ids=[ds1.id, ds2.id],
use_combined_context=True,
)
assert combined_context == [e1, e2]
assert completion == ["answer"]
assert datasets == [ds1, ds2]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_in_datasets_context_two_tool_context_override_and_is_empty_branches( async def test_search_in_datasets_context_two_tool_context_override_and_is_empty_branches(
monkeypatch, search_mod monkeypatch, search_mod

View file

@ -183,30 +183,6 @@ async def test_search_access_control_results_edges_become_graph_result(monkeypat
assert "edges" in out[0]["search_result"][0] assert "edges" in out[0]["search_result"][0]
@pytest.mark.asyncio
async def test_search_use_combined_context_defaults_empty_datasets(monkeypatch, search_mod):
user = types.SimpleNamespace(id="u1", tenant_id=None)
async def dummy_authorized_search(**_kwargs):
return ("answer", "ctx", [])
monkeypatch.setattr(search_mod, "backend_access_control_enabled", lambda: True)
monkeypatch.setattr(search_mod, "authorized_search", dummy_authorized_search)
out = await search_mod.search(
query_text="q",
query_type=SearchType.CHUNKS,
dataset_ids=None,
user=user,
use_combined_context=True,
verbose=True,
)
assert out.result == "answer"
assert out.context == {"all available datasets": "ctx"}
assert out.datasets[0].name == "all available datasets"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_access_control_context_str_branch(monkeypatch, search_mod): async def test_search_access_control_context_str_branch(monkeypatch, search_mod):
"""Covers prepare_search_result(context is str) through search().""" """Covers prepare_search_result(context is str) through search()."""