refactor: remove combined search functionality
- Remove use_combined_context parameter from search functions
- Remove CombinedSearchResult class from types module
- Update API routers to remove combined search support
- Remove prepare_combined_context helper function
- Update tutorial notebook to remove use_combined_context usage
- Simplify search return types to always return List[SearchResult]

This removes the combined search feature, which aggregated results across multiple datasets into a single response. Users can still search across multiple datasets and get results per dataset.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
ab990f7c5c
commit
748b7aeaf5
6 changed files with 51 additions and 163 deletions
|
|
@ -6,7 +6,7 @@ from fastapi import Depends, APIRouter
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
|
||||||
from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult
|
from cognee.modules.search.types import SearchType, SearchResult
|
||||||
from cognee.api.DTO import InDTO, OutDTO
|
from cognee.api.DTO import InDTO, OutDTO
|
||||||
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError
|
from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
|
|
@ -31,7 +31,6 @@ class SearchPayloadDTO(InDTO):
|
||||||
node_name: Optional[list[str]] = Field(default=None, example=[])
|
node_name: Optional[list[str]] = Field(default=None, example=[])
|
||||||
top_k: Optional[int] = Field(default=10)
|
top_k: Optional[int] = Field(default=10)
|
||||||
only_context: bool = Field(default=False)
|
only_context: bool = Field(default=False)
|
||||||
use_combined_context: bool = Field(default=False)
|
|
||||||
|
|
||||||
|
|
||||||
def get_search_router() -> APIRouter:
|
def get_search_router() -> APIRouter:
|
||||||
|
|
@ -74,7 +73,7 @@ def get_search_router() -> APIRouter:
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(status_code=500, content={"error": str(error)})
|
return JSONResponse(status_code=500, content={"error": str(error)})
|
||||||
|
|
||||||
@router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List])
|
@router.post("", response_model=Union[List[SearchResult], List])
|
||||||
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
|
||||||
"""
|
"""
|
||||||
Search for nodes in the graph database.
|
Search for nodes in the graph database.
|
||||||
|
|
@ -118,7 +117,6 @@ def get_search_router() -> APIRouter:
|
||||||
"node_name": payload.node_name,
|
"node_name": payload.node_name,
|
||||||
"top_k": payload.top_k,
|
"top_k": payload.top_k,
|
||||||
"only_context": payload.only_context,
|
"only_context": payload.only_context,
|
||||||
"use_combined_context": payload.use_combined_context,
|
|
||||||
"cognee_version": cognee_version,
|
"cognee_version": cognee_version,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -136,7 +134,6 @@ def get_search_router() -> APIRouter:
|
||||||
node_name=payload.node_name,
|
node_name=payload.node_name,
|
||||||
top_k=payload.top_k,
|
top_k=payload.top_k,
|
||||||
only_context=payload.only_context,
|
only_context=payload.only_context,
|
||||||
use_combined_context=payload.use_combined_context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonable_encoder(results)
|
return jsonable_encoder(results)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from typing import Union, Optional, List, Type
|
||||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.modules.engine.models.node_set import NodeSet
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult
|
from cognee.modules.search.types import SearchResult, SearchType
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
from cognee.modules.search.methods import search as search_function
|
from cognee.modules.search.methods import search as search_function
|
||||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||||
|
|
@ -32,11 +32,10 @@ async def search(
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = 1,
|
last_k: Optional[int] = 1,
|
||||||
only_context: bool = False,
|
only_context: bool = False,
|
||||||
use_combined_context: bool = False,
|
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
wide_search_top_k: Optional[int] = 100,
|
wide_search_top_k: Optional[int] = 100,
|
||||||
triplet_distance_penalty: Optional[float] = 3.5,
|
triplet_distance_penalty: Optional[float] = 3.5,
|
||||||
) -> Union[List[SearchResult], CombinedSearchResult]:
|
) -> List[SearchResult]:
|
||||||
"""
|
"""
|
||||||
Search and query the knowledge graph for insights, information, and connections.
|
Search and query the knowledge graph for insights, information, and connections.
|
||||||
|
|
||||||
|
|
@ -214,7 +213,6 @@ async def search(
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
use_combined_context=use_combined_context,
|
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
wide_search_top_k=wide_search_top_k,
|
wide_search_top_k=wide_search_top_k,
|
||||||
triplet_distance_penalty=triplet_distance_penalty,
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,5 @@ await cognee.cognify(datasets=["python-development-with-cognee"], temporal_cogni
|
||||||
results = await cognee.search(
|
results = await cognee.search(
|
||||||
"What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?",
|
"What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?",
|
||||||
datasets=["python-development-with-cognee"],
|
datasets=["python-development-with-cognee"],
|
||||||
use_combined_context=True, # Used to show reasoning graph visualization
|
|
||||||
)
|
)
|
||||||
print(results)
|
print(results)
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,6 @@ from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
from cognee.modules.search.types import (
|
from cognee.modules.search.types import (
|
||||||
SearchResult,
|
SearchResult,
|
||||||
CombinedSearchResult,
|
|
||||||
SearchResultDataset,
|
|
||||||
SearchType,
|
SearchType,
|
||||||
)
|
)
|
||||||
from cognee.modules.search.operations import log_query, log_result
|
from cognee.modules.search.operations import log_query, log_result
|
||||||
|
|
@ -45,11 +43,10 @@ async def search(
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
only_context: bool = False,
|
only_context: bool = False,
|
||||||
use_combined_context: bool = False,
|
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
wide_search_top_k: Optional[int] = 100,
|
wide_search_top_k: Optional[int] = 100,
|
||||||
triplet_distance_penalty: Optional[float] = 3.5,
|
triplet_distance_penalty: Optional[float] = 3.5,
|
||||||
) -> Union[CombinedSearchResult, List[SearchResult]]:
|
) -> List[SearchResult]:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -90,7 +87,6 @@ async def search(
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
use_combined_context=use_combined_context,
|
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
wide_search_top_k=wide_search_top_k,
|
wide_search_top_k=wide_search_top_k,
|
||||||
triplet_distance_penalty=triplet_distance_penalty,
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
|
|
@ -127,87 +123,59 @@ async def search(
|
||||||
query.id,
|
query.id,
|
||||||
json.dumps(
|
json.dumps(
|
||||||
jsonable_encoder(
|
jsonable_encoder(
|
||||||
await prepare_search_result(
|
[await prepare_search_result(search_result) for search_result in search_results]
|
||||||
search_results[0] if isinstance(search_results, list) else search_results
|
|
||||||
)
|
|
||||||
if use_combined_context
|
|
||||||
else [
|
|
||||||
await prepare_search_result(search_result) for search_result in search_results
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
user.id,
|
user.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_combined_context:
|
# This is for maintaining backwards compatibility
|
||||||
prepared_search_results = await prepare_search_result(
|
if backend_access_control_enabled():
|
||||||
search_results[0] if isinstance(search_results, list) else search_results
|
return_value = []
|
||||||
)
|
for search_result in search_results:
|
||||||
result = prepared_search_results["result"]
|
prepared_search_results = await prepare_search_result(search_result)
|
||||||
graphs = prepared_search_results["graphs"]
|
|
||||||
context = prepared_search_results["context"]
|
|
||||||
datasets = prepared_search_results["datasets"]
|
|
||||||
|
|
||||||
return CombinedSearchResult(
|
result = prepared_search_results["result"]
|
||||||
result=result,
|
graphs = prepared_search_results["graphs"]
|
||||||
graphs=graphs,
|
context = prepared_search_results["context"]
|
||||||
context=context,
|
datasets = prepared_search_results["datasets"]
|
||||||
datasets=[
|
|
||||||
SearchResultDataset(
|
if only_context:
|
||||||
id=dataset.id,
|
return_value.append(
|
||||||
name=dataset.name,
|
{
|
||||||
|
"search_result": [context] if context else None,
|
||||||
|
"dataset_id": datasets[0].id,
|
||||||
|
"dataset_name": datasets[0].name,
|
||||||
|
"dataset_tenant_id": datasets[0].tenant_id,
|
||||||
|
"graphs": graphs,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
for dataset in datasets
|
else:
|
||||||
],
|
return_value.append(
|
||||||
)
|
{
|
||||||
|
"search_result": [result] if result else None,
|
||||||
|
"dataset_id": datasets[0].id,
|
||||||
|
"dataset_name": datasets[0].name,
|
||||||
|
"dataset_tenant_id": datasets[0].tenant_id,
|
||||||
|
"graphs": graphs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return return_value
|
||||||
else:
|
else:
|
||||||
# This is for maintaining backwards compatibility
|
return_value = []
|
||||||
if backend_access_control_enabled():
|
if only_context:
|
||||||
return_value = []
|
|
||||||
for search_result in search_results:
|
for search_result in search_results:
|
||||||
prepared_search_results = await prepare_search_result(search_result)
|
prepared_search_results = await prepare_search_result(search_result)
|
||||||
|
return_value.append(prepared_search_results["context"])
|
||||||
result = prepared_search_results["result"]
|
|
||||||
graphs = prepared_search_results["graphs"]
|
|
||||||
context = prepared_search_results["context"]
|
|
||||||
datasets = prepared_search_results["datasets"]
|
|
||||||
|
|
||||||
if only_context:
|
|
||||||
return_value.append(
|
|
||||||
{
|
|
||||||
"search_result": [context] if context else None,
|
|
||||||
"dataset_id": datasets[0].id,
|
|
||||||
"dataset_name": datasets[0].name,
|
|
||||||
"dataset_tenant_id": datasets[0].tenant_id,
|
|
||||||
"graphs": graphs,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return_value.append(
|
|
||||||
{
|
|
||||||
"search_result": [result] if result else None,
|
|
||||||
"dataset_id": datasets[0].id,
|
|
||||||
"dataset_name": datasets[0].name,
|
|
||||||
"dataset_tenant_id": datasets[0].tenant_id,
|
|
||||||
"graphs": graphs,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return return_value
|
|
||||||
else:
|
else:
|
||||||
return_value = []
|
for search_result in search_results:
|
||||||
if only_context:
|
result, context, datasets = search_result
|
||||||
for search_result in search_results:
|
return_value.append(result)
|
||||||
prepared_search_results = await prepare_search_result(search_result)
|
# For maintaining backwards compatibility
|
||||||
return_value.append(prepared_search_results["context"])
|
if len(return_value) == 1 and isinstance(return_value[0], list):
|
||||||
else:
|
return return_value[0]
|
||||||
for search_result in search_results:
|
else:
|
||||||
result, context, datasets = search_result
|
return return_value
|
||||||
return_value.append(result)
|
|
||||||
# For maintaining backwards compatibility
|
|
||||||
if len(return_value) == 1 and isinstance(return_value[0], list):
|
|
||||||
return return_value[0]
|
|
||||||
else:
|
|
||||||
return return_value
|
|
||||||
|
|
||||||
|
|
||||||
async def authorized_search(
|
async def authorized_search(
|
||||||
|
|
@ -223,14 +191,10 @@ async def authorized_search(
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
only_context: bool = False,
|
only_context: bool = False,
|
||||||
use_combined_context: bool = False,
|
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
wide_search_top_k: Optional[int] = 100,
|
wide_search_top_k: Optional[int] = 100,
|
||||||
triplet_distance_penalty: Optional[float] = 3.5,
|
triplet_distance_penalty: Optional[float] = 3.5,
|
||||||
) -> Union[
|
) -> List[Tuple[Any, Union[List[Edge], str], List[Dataset]]]:
|
||||||
Tuple[Any, Union[List[Edge], str], List[Dataset]],
|
|
||||||
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
|
|
||||||
]:
|
|
||||||
"""
|
"""
|
||||||
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
|
||||||
Not to be used outside of active access control mode.
|
Not to be used outside of active access control mode.
|
||||||
|
|
@ -240,70 +204,6 @@ async def authorized_search(
|
||||||
datasets=dataset_ids, permission_type="read", user=user
|
datasets=dataset_ids, permission_type="read", user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_combined_context:
|
|
||||||
search_responses = await search_in_datasets_context(
|
|
||||||
search_datasets=search_datasets,
|
|
||||||
query_type=query_type,
|
|
||||||
query_text=query_text,
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
top_k=top_k,
|
|
||||||
node_type=node_type,
|
|
||||||
node_name=node_name,
|
|
||||||
save_interaction=save_interaction,
|
|
||||||
last_k=last_k,
|
|
||||||
only_context=True,
|
|
||||||
session_id=session_id,
|
|
||||||
wide_search_top_k=wide_search_top_k,
|
|
||||||
triplet_distance_penalty=triplet_distance_penalty,
|
|
||||||
)
|
|
||||||
|
|
||||||
context = {}
|
|
||||||
datasets: List[Dataset] = []
|
|
||||||
|
|
||||||
for _, search_context, search_datasets in search_responses:
|
|
||||||
for dataset in search_datasets:
|
|
||||||
context[str(dataset.id)] = search_context
|
|
||||||
|
|
||||||
datasets.extend(search_datasets)
|
|
||||||
|
|
||||||
specific_search_tools = await get_search_type_tools(
|
|
||||||
query_type=query_type,
|
|
||||||
query_text=query_text,
|
|
||||||
system_prompt_path=system_prompt_path,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
top_k=top_k,
|
|
||||||
node_type=node_type,
|
|
||||||
node_name=node_name,
|
|
||||||
save_interaction=save_interaction,
|
|
||||||
last_k=last_k,
|
|
||||||
wide_search_top_k=wide_search_top_k,
|
|
||||||
triplet_distance_penalty=triplet_distance_penalty,
|
|
||||||
)
|
|
||||||
search_tools = specific_search_tools
|
|
||||||
if len(search_tools) == 2:
|
|
||||||
[get_completion, _] = search_tools
|
|
||||||
else:
|
|
||||||
get_completion = search_tools[0]
|
|
||||||
|
|
||||||
def prepare_combined_context(
|
|
||||||
context,
|
|
||||||
) -> Union[List[Edge], str]:
|
|
||||||
combined_context = []
|
|
||||||
|
|
||||||
for dataset_context in context.values():
|
|
||||||
combined_context += dataset_context
|
|
||||||
|
|
||||||
if combined_context and isinstance(combined_context[0], str):
|
|
||||||
return "\n".join(combined_context)
|
|
||||||
|
|
||||||
return combined_context
|
|
||||||
|
|
||||||
combined_context = prepare_combined_context(context)
|
|
||||||
completion = await get_completion(query_text, combined_context, session_id=session_id)
|
|
||||||
|
|
||||||
return completion, combined_context, datasets
|
|
||||||
|
|
||||||
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
|
||||||
search_results = await search_in_datasets_context(
|
search_results = await search_in_datasets_context(
|
||||||
search_datasets=search_datasets,
|
search_datasets=search_datasets,
|
||||||
|
|
@ -319,6 +219,7 @@ async def authorized_search(
|
||||||
only_context=only_context,
|
only_context=only_context,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
wide_search_top_k=wide_search_top_k,
|
wide_search_top_k=wide_search_top_k,
|
||||||
|
triplet_distance_penalty=triplet_distance_penalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
return search_results
|
return search_results
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
class SearchResultDataset(BaseModel):
|
class SearchResultDataset(BaseModel):
|
||||||
|
|
@ -8,13 +8,6 @@ class SearchResultDataset(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class CombinedSearchResult(BaseModel):
|
|
||||||
result: Optional[Any]
|
|
||||||
context: Dict[str, Any]
|
|
||||||
graphs: Optional[Dict[str, Any]] = {}
|
|
||||||
datasets: Optional[List[SearchResultDataset]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResult(BaseModel):
|
class SearchResult(BaseModel):
|
||||||
search_result: Any
|
search_result: Any
|
||||||
dataset_id: Optional[UUID]
|
dataset_id: Optional[UUID]
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,2 @@
|
||||||
from .SearchType import SearchType
|
from .SearchType import SearchType
|
||||||
from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult
|
from .SearchResult import SearchResult, SearchResultDataset
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue