refactor: remove combined search functionality
- Remove use_combined_context parameter from search functions
- Remove CombinedSearchResult class from types module
- Update API routers to remove combined search support
- Remove prepare_combined_context helper function
- Update tutorial notebook to remove use_combined_context usage
- Simplify search return types to always return List[SearchResult]

This removes the combined search feature, which aggregated results across multiple datasets into a single response. Users can still search across multiple datasets and get results per dataset.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
ab990f7c5c
commit
748b7aeaf5
6 changed files with 51 additions and 163 deletions
|
|
@ -6,7 +6,7 @@ from fastapi import Depends, APIRouter
|
|||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult
|
||||
from cognee.modules.search.types import SearchType, SearchResult
|
||||
from cognee.api.DTO import InDTO, OutDTO
|
||||
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError
|
||||
from cognee.modules.users.models import User
|
||||
|
|
@ -31,7 +31,6 @@ class SearchPayloadDTO(InDTO):
|
|||
node_name: Optional[list[str]] = Field(default=None, example=[])
|
||||
top_k: Optional[int] = Field(default=10)
|
||||
only_context: bool = Field(default=False)
|
||||
use_combined_context: bool = Field(default=False)
|
||||
|
||||
|
||||
def get_search_router() -> APIRouter:
|
||||
|
|
@ -74,7 +73,7 @@ def get_search_router() -> APIRouter:
|
|||
except Exception as error:
|
||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||
|
||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
|
||||
@router.post("", response_model=Union[List[SearchResult], List])
|
||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||
"""
|
||||
Search for nodes in the graph database.
|
||||
|
|
@ -118,7 +117,6 @@ def get_search_router() -> APIRouter:
|
|||
"node_name": payload.node_name,
|
||||
"top_k": payload.top_k,
|
||||
"only_context": payload.only_context,
|
||||
"use_combined_context": payload.use_combined_context,
|
||||
"cognee_version": cognee_version,
|
||||
},
|
||||
)
|
||||
|
|
@ -136,7 +134,6 @@ def get_search_router() -> APIRouter:
|
|||
node_name=payload.node_name,
|
||||
top_k=payload.top_k,
|
||||
only_context=payload.only_context,
|
||||
use_combined_context=payload.use_combined_context,
|
||||
)
|
||||
|
||||
return jsonable_encoder(results)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Union, Optional, List, Type
|
|||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
|
||||
from cognee.modules.search.types import SearchResult, SearchType
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.search.methods import search as search_function
|
||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||
|
|
@ -32,11 +32,10 @@ async def search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = 1,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[List[SearchResult], CombinedSearchResult]:
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Search and query the knowledge graph for insights, information, and connections.
|
||||
|
||||
|
|
@ -214,7 +213,6 @@ async def search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
|
|
|
|||
|
|
@ -27,6 +27,5 @@ await cognee.cognify(datasets=["python-development-with-cognee"], temporal_cogni
|
|||
results = await cognee.search(
|
||||
"What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?",
|
||||
datasets=["python-development-with-cognee"],
|
||||
use_combined_context=True, # Used to show reasoning graph visualization
|
||||
)
|
||||
print(results)
|
||||
|
|
|
|||
|
|
@ -14,8 +14,6 @@ from cognee.modules.engine.models.node_set import NodeSet
|
|||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.search.types import (
|
||||
SearchResult,
|
||||
CombinedSearchResult,
|
||||
SearchResultDataset,
|
||||
SearchType,
|
||||
)
|
||||
from cognee.modules.search.operations import log_query, log_result
|
||||
|
|
@ -45,11 +43,10 @@ async def search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
|
|
@ -90,7 +87,6 @@ async def search(
|
|||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=only_context,
|
||||
use_combined_context=use_combined_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
|
|
@ -127,87 +123,59 @@ async def search(
|
|||
query.id,
|
||||
json.dumps(
|
||||
jsonable_encoder(
|
||||
await prepare_search_result(
|
||||
search_results[0] if isinstance(search_results, list) else search_results
|
||||
)
|
||||
if use_combined_context
|
||||
else [
|
||||
await prepare_search_result(search_result) for search_result in search_results
|
||||
]
|
||||
[await prepare_search_result(search_result) for search_result in search_results]
|
||||
)
|
||||
),
|
||||
user.id,
|
||||
)
|
||||
|
||||
if use_combined_context:
|
||||
prepared_search_results = await prepare_search_result(
|
||||
search_results[0] if isinstance(search_results, list) else search_results
|
||||
)
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
# This is for maintaining backwards compatibility
|
||||
if backend_access_control_enabled():
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
||||
return CombinedSearchResult(
|
||||
result=result,
|
||||
graphs=graphs,
|
||||
context=context,
|
||||
datasets=[
|
||||
SearchResultDataset(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
if only_context:
|
||||
return_value.append(
|
||||
{
|
||||
"search_result": [context] if context else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
"graphs": graphs,
|
||||
}
|
||||
)
|
||||
for dataset in datasets
|
||||
],
|
||||
)
|
||||
else:
|
||||
return_value.append(
|
||||
{
|
||||
"search_result": [result] if result else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
"graphs": graphs,
|
||||
}
|
||||
)
|
||||
return return_value
|
||||
else:
|
||||
# This is for maintaining backwards compatibility
|
||||
if backend_access_control_enabled():
|
||||
return_value = []
|
||||
return_value = []
|
||||
if only_context:
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
||||
result = prepared_search_results["result"]
|
||||
graphs = prepared_search_results["graphs"]
|
||||
context = prepared_search_results["context"]
|
||||
datasets = prepared_search_results["datasets"]
|
||||
|
||||
if only_context:
|
||||
return_value.append(
|
||||
{
|
||||
"search_result": [context] if context else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
"graphs": graphs,
|
||||
}
|
||||
)
|
||||
else:
|
||||
return_value.append(
|
||||
{
|
||||
"search_result": [result] if result else None,
|
||||
"dataset_id": datasets[0].id,
|
||||
"dataset_name": datasets[0].name,
|
||||
"dataset_tenant_id": datasets[0].tenant_id,
|
||||
"graphs": graphs,
|
||||
}
|
||||
)
|
||||
return return_value
|
||||
return_value.append(prepared_search_results["context"])
|
||||
else:
|
||||
return_value = []
|
||||
if only_context:
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
return_value.append(prepared_search_results["context"])
|
||||
else:
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(result)
|
||||
# For maintaining backwards compatibility
|
||||
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||
return return_value[0]
|
||||
else:
|
||||
return return_value
|
||||
for search_result in search_results:
|
||||
result, context, datasets = search_result
|
||||
return_value.append(result)
|
||||
# For maintaining backwards compatibility
|
||||
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||
return return_value[0]
|
||||
else:
|
||||
return return_value
|
||||
|
||||
|
||||
async def authorized_search(
|
||||
|
|
@ -223,14 +191,10 @@ async def authorized_search(
|
|||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
only_context: bool = False,
|
||||
use_combined_context: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
wide_search_top_k: Optional[int] = 100,
|
||||
triplet_distance_penalty: Optional[float] = 3.5,
|
||||
) -> Union[
|
||||
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
||||
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
||||
]:
|
||||
) -> List[Tuple[Any, Union[List[Edge], str], List[Dataset]]]:
|
||||
"""
|
||||
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
||||
Not to be used outside of active access control mode.
|
||||
|
|
@ -240,70 +204,6 @@ async def authorized_search(
|
|||
datasets=dataset_ids, permission_type="read", user=user
|
||||
)
|
||||
|
||||
if use_combined_context:
|
||||
search_responses = await search_in_datasets_context(
|
||||
search_datasets=search_datasets,
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
only_context=True,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
context = {}
|
||||
datasets: List[Dataset] = []
|
||||
|
||||
for _, search_context, search_datasets in search_responses:
|
||||
for dataset in search_datasets:
|
||||
context[str(dataset.id)] = search_context
|
||||
|
||||
datasets.extend(search_datasets)
|
||||
|
||||
specific_search_tools = await get_search_type_tools(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
system_prompt_path=system_prompt_path,
|
||||
system_prompt=system_prompt,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
search_tools = specific_search_tools
|
||||
if len(search_tools) == 2:
|
||||
[get_completion, _] = search_tools
|
||||
else:
|
||||
get_completion = search_tools[0]
|
||||
|
||||
def prepare_combined_context(
|
||||
context,
|
||||
) -> Union[List[Edge], str]:
|
||||
combined_context = []
|
||||
|
||||
for dataset_context in context.values():
|
||||
combined_context += dataset_context
|
||||
|
||||
if combined_context and isinstance(combined_context[0], str):
|
||||
return "\n".join(combined_context)
|
||||
|
||||
return combined_context
|
||||
|
||||
combined_context = prepare_combined_context(context)
|
||||
completion = await get_completion(query_text, combined_context, session_id=session_id)
|
||||
|
||||
return completion, combined_context, datasets
|
||||
|
||||
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
||||
search_results = await search_in_datasets_context(
|
||||
search_datasets=search_datasets,
|
||||
|
|
@ -319,6 +219,7 @@ async def authorized_search(
|
|||
only_context=only_context,
|
||||
session_id=session_id,
|
||||
wide_search_top_k=wide_search_top_k,
|
||||
triplet_distance_penalty=triplet_distance_penalty,
|
||||
)
|
||||
|
||||
return search_results
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class SearchResultDataset(BaseModel):
|
||||
|
|
@ -8,13 +8,6 @@ class SearchResultDataset(BaseModel):
|
|||
name: str
|
||||
|
||||
|
||||
class CombinedSearchResult(BaseModel):
|
||||
result: Optional[Any]
|
||||
context: Dict[str, Any]
|
||||
graphs: Optional[Dict[str, Any]] = {}
|
||||
datasets: Optional[List[SearchResultDataset]] = None
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
search_result: Any
|
||||
dataset_id: Optional[UUID]
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from .SearchType import SearchType
|
||||
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult
|
||||
from .SearchResult import SearchResult, SearchResultDataset
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue