refactor: remove combined search functionality

- Remove use_combined_context parameter from search functions
- Remove CombinedSearchResult class from types module
- Update API routers to remove combined search support
- Remove prepare_combined_context helper function
- Update tutorial notebook to remove use_combined_context usage
- Simplify search return types to always return List[SearchResult]

This removes the combined search feature which aggregated results
across multiple datasets into a single response. Users can still
search across multiple datasets and get results per dataset.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
vasilije 2026-01-12 07:43:57 +01:00
parent ab990f7c5c
commit 748b7aeaf5
6 changed files with 51 additions and 163 deletions

View file

@ -6,7 +6,7 @@ from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult from cognee.modules.search.types import SearchType, SearchResult
from cognee.api.DTO import InDTO, OutDTO from cognee.api.DTO import InDTO, OutDTO
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError
from cognee.modules.users.models import User from cognee.modules.users.models import User
@ -31,7 +31,6 @@ class SearchPayloadDTO(InDTO):
node_name: Optional[list[str]] = Field(default=None, example=[]) node_name: Optional[list[str]] = Field(default=None, example=[])
top_k: Optional[int] = Field(default=10) top_k: Optional[int] = Field(default=10)
only_context: bool = Field(default=False) only_context: bool = Field(default=False)
use_combined_context: bool = Field(default=False)
def get_search_router() -> APIRouter: def get_search_router() -> APIRouter:
@ -74,7 +73,7 @@ def get_search_router() -> APIRouter:
except Exception as error: except Exception as error:
return JSONResponse(status_code=500, content={"error": str(error)}) return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List]) @router.post("", response_model=Union[List[SearchResult], List])
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
""" """
Search for nodes in the graph database. Search for nodes in the graph database.
@ -118,7 +117,6 @@ def get_search_router() -> APIRouter:
"node_name": payload.node_name, "node_name": payload.node_name,
"top_k": payload.top_k, "top_k": payload.top_k,
"only_context": payload.only_context, "only_context": payload.only_context,
"use_combined_context": payload.use_combined_context,
"cognee_version": cognee_version, "cognee_version": cognee_version,
}, },
) )
@ -136,7 +134,6 @@ def get_search_router() -> APIRouter:
node_name=payload.node_name, node_name=payload.node_name,
top_k=payload.top_k, top_k=payload.top_k,
only_context=payload.only_context, only_context=payload.only_context,
use_combined_context=payload.use_combined_context,
) )
return jsonable_encoder(results) return jsonable_encoder(results)

View file

@ -4,7 +4,7 @@ from typing import Union, Optional, List, Type
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.users.models import User from cognee.modules.users.models import User
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult from cognee.modules.search.types import SearchResult, SearchType
from cognee.modules.users.methods import get_default_user from cognee.modules.users.methods import get_default_user
from cognee.modules.search.methods import search as search_function from cognee.modules.search.methods import search as search_function
from cognee.modules.data.methods import get_authorized_existing_datasets from cognee.modules.data.methods import get_authorized_existing_datasets
@ -32,11 +32,10 @@ async def search(
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = 1, last_k: Optional[int] = 1,
only_context: bool = False, only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None, session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100, wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5, triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[List[SearchResult], CombinedSearchResult]: ) -> List[SearchResult]:
""" """
Search and query the knowledge graph for insights, information, and connections. Search and query the knowledge graph for insights, information, and connections.
@ -214,7 +213,6 @@ async def search(
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context, only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id, session_id=session_id,
wide_search_top_k=wide_search_top_k, wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty, triplet_distance_penalty=triplet_distance_penalty,

View file

@ -27,6 +27,5 @@ await cognee.cognify(datasets=["python-development-with-cognee"], temporal_cogni
results = await cognee.search( results = await cognee.search(
"What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?", "What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?",
datasets=["python-development-with-cognee"], datasets=["python-development-with-cognee"],
use_combined_context=True, # Used to show reasoning graph visualization
) )
print(results) print(results)

View file

@ -14,8 +14,6 @@ from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.search.types import ( from cognee.modules.search.types import (
SearchResult, SearchResult,
CombinedSearchResult,
SearchResultDataset,
SearchType, SearchType,
) )
from cognee.modules.search.operations import log_query, log_result from cognee.modules.search.operations import log_query, log_result
@ -45,11 +43,10 @@ async def search(
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = False, only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None, session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100, wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5, triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[CombinedSearchResult, List[SearchResult]]: ) -> List[SearchResult]:
""" """
Args: Args:
@ -90,7 +87,6 @@ async def search(
save_interaction=save_interaction, save_interaction=save_interaction,
last_k=last_k, last_k=last_k,
only_context=only_context, only_context=only_context,
use_combined_context=use_combined_context,
session_id=session_id, session_id=session_id,
wide_search_top_k=wide_search_top_k, wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty, triplet_distance_penalty=triplet_distance_penalty,
@ -127,87 +123,59 @@ async def search(
query.id, query.id,
json.dumps( json.dumps(
jsonable_encoder( jsonable_encoder(
await prepare_search_result( [await prepare_search_result(search_result) for search_result in search_results]
search_results[0] if isinstance(search_results, list) else search_results
)
if use_combined_context
else [
await prepare_search_result(search_result) for search_result in search_results
]
) )
), ),
user.id, user.id,
) )
if use_combined_context: # This is for maintaining backwards compatibility
prepared_search_results = await prepare_search_result( if backend_access_control_enabled():
search_results[0] if isinstance(search_results, list) else search_results return_value = []
) for search_result in search_results:
result = prepared_search_results["result"] prepared_search_results = await prepare_search_result(search_result)
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
return CombinedSearchResult( result = prepared_search_results["result"]
result=result, graphs = prepared_search_results["graphs"]
graphs=graphs, context = prepared_search_results["context"]
context=context, datasets = prepared_search_results["datasets"]
datasets=[
SearchResultDataset( if only_context:
id=dataset.id, return_value.append(
name=dataset.name, {
"search_result": [context] if context else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
) )
for dataset in datasets else:
], return_value.append(
) {
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
return return_value
else: else:
# This is for maintaining backwards compatibility return_value = []
if backend_access_control_enabled(): if only_context:
return_value = []
for search_result in search_results: for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result) prepared_search_results = await prepare_search_result(search_result)
return_value.append(prepared_search_results["context"])
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
if only_context:
return_value.append(
{
"search_result": [context] if context else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
else:
return_value.append(
{
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"dataset_tenant_id": datasets[0].tenant_id,
"graphs": graphs,
}
)
return return_value
else: else:
return_value = [] for search_result in search_results:
if only_context: result, context, datasets = search_result
for search_result in search_results: return_value.append(result)
prepared_search_results = await prepare_search_result(search_result) # For maintaining backwards compatibility
return_value.append(prepared_search_results["context"]) if len(return_value) == 1 and isinstance(return_value[0], list):
else: return return_value[0]
for search_result in search_results: else:
result, context, datasets = search_result return return_value
return_value.append(result)
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
async def authorized_search( async def authorized_search(
@ -223,14 +191,10 @@ async def authorized_search(
save_interaction: bool = False, save_interaction: bool = False,
last_k: Optional[int] = None, last_k: Optional[int] = None,
only_context: bool = False, only_context: bool = False,
use_combined_context: bool = False,
session_id: Optional[str] = None, session_id: Optional[str] = None,
wide_search_top_k: Optional[int] = 100, wide_search_top_k: Optional[int] = 100,
triplet_distance_penalty: Optional[float] = 3.5, triplet_distance_penalty: Optional[float] = 3.5,
) -> Union[ ) -> List[Tuple[Any, Union[List[Edge], str], List[Dataset]]]:
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
]:
""" """
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
Not to be used outside of active access control mode. Not to be used outside of active access control mode.
@ -240,70 +204,6 @@ async def authorized_search(
datasets=dataset_ids, permission_type="read", user=user datasets=dataset_ids, permission_type="read", user=user
) )
if use_combined_context:
search_responses = await search_in_datasets_context(
search_datasets=search_datasets,
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=True,
session_id=session_id,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
context = {}
datasets: List[Dataset] = []
for _, search_context, search_datasets in search_responses:
for dataset in search_datasets:
context[str(dataset.id)] = search_context
datasets.extend(search_datasets)
specific_search_tools = await get_search_type_tools(
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
[get_completion, _] = search_tools
else:
get_completion = search_tools[0]
def prepare_combined_context(
context,
) -> Union[List[Edge], str]:
combined_context = []
for dataset_context in context.values():
combined_context += dataset_context
if combined_context and isinstance(combined_context[0], str):
return "\n".join(combined_context)
return combined_context
combined_context = prepare_combined_context(context)
completion = await get_completion(query_text, combined_context, session_id=session_id)
return completion, combined_context, datasets
# Searches all provided datasets and handles setting up of appropriate database context based on permissions # Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await search_in_datasets_context( search_results = await search_in_datasets_context(
search_datasets=search_datasets, search_datasets=search_datasets,
@ -319,6 +219,7 @@ async def authorized_search(
only_context=only_context, only_context=only_context,
session_id=session_id, session_id=session_id,
wide_search_top_k=wide_search_top_k, wide_search_top_k=wide_search_top_k,
triplet_distance_penalty=triplet_distance_penalty,
) )
return search_results return search_results

View file

@ -1,6 +1,6 @@
from uuid import UUID from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel
from typing import Any, Dict, List, Optional from typing import Any, Optional
class SearchResultDataset(BaseModel): class SearchResultDataset(BaseModel):
@ -8,13 +8,6 @@ class SearchResultDataset(BaseModel):
name: str name: str
class CombinedSearchResult(BaseModel):
result: Optional[Any]
context: Dict[str, Any]
graphs: Optional[Dict[str, Any]] = {}
datasets: Optional[List[SearchResultDataset]] = None
class SearchResult(BaseModel): class SearchResult(BaseModel):
search_result: Any search_result: Any
dataset_id: Optional[UUID] dataset_id: Optional[UUID]

View file

@ -1,2 +1,2 @@
from .SearchType import SearchType from .SearchType import SearchType
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult from .SearchResult import SearchResult, SearchResultDataset