refactor: remove combined search functionality

- Remove use_combined_context parameter from search functions
- Remove CombinedSearchResult class from types module
- Update API routers to remove combined search support
- Remove prepare_combined_context helper function
- Update tutorial notebook to remove use_combined_context usage
- Simplify search return types to always return List[SearchResult]

This removes the combined search feature which aggregated results
across multiple datasets into a single response. Users can still
search across multiple datasets and get results per dataset.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
vasilije 2026-01-12 07:43:57 +01:00
parent ab990f7c5c
commit 748b7aeaf5
6 changed files with 51 additions and 163 deletions

View file

@@ -6,7 +6,7 @@ from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult
from cognee.modules.search.types import SearchType, SearchResult
from cognee.api.DTO import InDTO, OutDTO
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError
from cognee.modules.users.models import User
@@ -31,7 +31,6 @@ class SearchPayloadDTO(InDTO):
node_name: Optional[list[str]] = Field(default=None, example=[])
top_k: Optional[int] = Field(default=10)
only_context: bool = Field(default=False)
use_combined_context: bool = Field(default=False)
def get_search_router() -> APIRouter:
@@ -74,7 +73,7 @@ def get_search_router() -> APIRouter:
except Exception as error:
return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
@router.post("", response_model=Union[List[SearchResult], List])
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Search for nodes in the graph database.
@@ -118,7 +117,6 @@ def get_search_router() -> APIRouter:
"node_name": payload.node_name,
"top_k": payload.top_k,
"only_context": payload.only_context,
"use_combined_context": payload.use_combined_context,
"cognee_version": cognee_version,
},
)
@@ -136,7 +134,6 @@ def get_search_router() -> APIRouter:
node_name=payload.node_name,
top_k=payload.top_k,
only_context=payload.only_context,
use_combined_context=payload.use_combined_context,
)
return jsonable_encoder(results)

View file

@@ -4,7 +4,7 @@ from typing import Union, Optional, List, Type
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.users.models import User
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
from cognee.modules.search.types import SearchResult, SearchType
from cognee.modules.users.methods import get_default_user
from cognee.modules.search.methods import search as search_function
from cognee.modules.data.methods import get_authorized_existing_datasets
@@ -32,11 +32,10 @@ async def search(
save_interaction: bool = False,
last_k: Optional[int] = 1,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]:
) -> List[SearchResult]:
"""
Search and query the knowledge graph for insights, information, and connections.
@@ -214,7 +213,6 @@ async def search(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,

View file

@@ -27,6 +27,5 @@ await cognee.cognify(datasets=["python-development-with-cognee"], temporal_cogni
results = await cognee.search(
"What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?",
datasets=["python-development-with-cognee"],
use_combined_context=True, # Used to show reasoning graph visualization
)
print(results)

View file

@@ -14,8 +14,6 @@ from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.search.types import (
SearchResult,
CombinedSearchResult,
SearchResultDataset,
SearchType,
)
from cognee.modules.search.operations import log_query, log_result
@@ -45,11 +43,10 @@ async def search(
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[CombinedSearchResult, List[SearchResult]]:
) -> List[SearchResult]:
"""
Args:
@@ -90,7 +87,6 @@ async def search(
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
@@ -127,87 +123,59 @@ async def search(
query.id,
json.dumps(
jsonable_encoder(
await prepare_search_result(
search_results[0] if isinstance(search_results, list) else search_results
)
if use_combined_context
else [
await prepare_search_result(search_result) for search_result in search_results
]
[await prepare_search_result(search_result) for search_result in search_results]
)
),
user.id,
)
if use_combined_context:
prepared_search_results = await prepare_search_result(
search_results[0] if isinstance(search_results, list) else search_results
)
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
# This is for maintaining backwards compatibility
if backend_access_control_enabled():
return_value = []
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)
return CombinedSearchResult(
result=result,
graphs=graphs,
context=context,
datasets=[
SearchResultDataset(
id=dataset.id,
name=dataset.name,
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
if only_context:
return_value.append(
{
"search_result": [context] if context else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
for dataset in datasets
],
)
else:
return_value.append(
{
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
return return_value
else:
# This is for maintaining backwards compatibility
if backend_access_control_enabled():
return_value = []
return_value = []
if only_context:
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
if only_context:
return_value.append(
{
"search_result": [context] if context else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
else:
return_value.append(
{
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
return return_value
return_value.append(prepared_search_results["context"])
else:
return_value = []
if only_context:
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)
return_value.append(prepared_search_results["context"])
else:
for search_result in search_results:
result, context, datasets = search_result
return_value.append(result)
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
for search_result in search_results:
result, context, datasets = search_result
return_value.append(result)
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
async def authorized_search(
@@ -223,14 +191,10 @@ async def authorized_search(
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
]:
) -> List[Tuple[Any, Union[List[Edge], str], List[Dataset]]]:
"""
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
Not to be used outside of active access control mode.
@@ -240,70 +204,6 @@ async def authorized_search(
datasets=dataset_ids, permission_type="read", user=user
)
if use_combined_context:
search_responses = await search_in_datasets_context(
search_datasets=search_datasets,
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=True,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
context = {}
datasets: List[Dataset] = []
for _, search_context, search_datasets in search_responses:
for dataset in search_datasets:
context[str(dataset.id)] = search_context
datasets.extend(search_datasets)
specific_search_tools = await get_search_type_tools(
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
[get_completion, _] = search_tools
else:
get_completion = search_tools[0]
def prepare_combined_context(
context,
) -> Union[List[Edge], str]:
combined_context = []
for dataset_context in context.values():
combined_context += dataset_context
if combined_context and isinstance(combined_context[0], str):
return "\n".join(combined_context)
return combined_context
combined_context = prepare_combined_context(context)
completion = await get_completion(query_text, combined_context, session_id=session_id)
return completion, combined_context, datasets
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await search_in_datasets_context(
search_datasets=search_datasets,
@@ -319,6 +219,7 @@ async def authorized_search(
only_context=only_context,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
return search_results

View file

@@ -1,6 +1,6 @@
from uuid import UUID
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from typing import Any, Optional
class SearchResultDataset(BaseModel):
@@ -8,13 +8,6 @@ class SearchResultDataset(BaseModel):
name: str
class CombinedSearchResult(BaseModel):
result: Optional[Any]
context: Dict[str, Any]
graphs: Optional[Dict[str, Any]] = {}
datasets: Optional[List[SearchResultDataset]] = None
class SearchResult(BaseModel):
search_result: Any
dataset_id: Optional[UUID]

View file

@@ -1,2 +1,2 @@
from .SearchType import SearchType
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult
from .SearchResult import SearchResult, SearchResultDataset