diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index 1aaed7f39..8b7a2f24b 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -6,7 +6,7 @@ from fastapi import Depends, APIRouter from fastapi.responses import JSONResponse from fastapi.encoders import jsonable_encoder -from cognee.modules.search.types import SearchType, SearchResult, CombinedSearchResult +from cognee.modules.search.types import SearchType, SearchResult from cognee.api.DTO import InDTO, OutDTO from cognee.modules.users.exceptions.exceptions import PermissionDeniedError, UserNotFoundError from cognee.modules.users.models import User @@ -31,7 +31,6 @@ class SearchPayloadDTO(InDTO): node_name: Optional[list[str]] = Field(default=None, example=[]) top_k: Optional[int] = Field(default=10) only_context: bool = Field(default=False) - use_combined_context: bool = Field(default=False) def get_search_router() -> APIRouter: @@ -74,7 +73,7 @@ def get_search_router() -> APIRouter: except Exception as error: return JSONResponse(status_code=500, content={"error": str(error)}) - @router.post("", response_model=Union[List[SearchResult], CombinedSearchResult, List]) + @router.post("", response_model=Union[List[SearchResult], List]) async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)): """ Search for nodes in the graph database. @@ -118,7 +117,6 @@ def get_search_router() -> APIRouter: "node_name": payload.node_name, "top_k": payload.top_k, "only_context": payload.only_context, - "use_combined_context": payload.use_combined_context, "cognee_version": cognee_version, }, ) @@ -136,7 +134,6 @@ def get_search_router() -> APIRouter: node_name=payload.node_name, top_k=payload.top_k, only_context=payload.only_context, - use_combined_context=payload.use_combined_context, ) return jsonable_encoder(results) diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index ee7408758..b2fdfb8a5 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -4,7 +4,7 @@ from typing import Union, Optional, List, Type from cognee.infrastructure.databases.graph import get_graph_engine from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.users.models import User -from cognee.modules.search.types import SearchResult, SearchType, CombinedSearchResult +from cognee.modules.search.types import SearchResult, SearchType from cognee.modules.users.methods import get_default_user from cognee.modules.search.methods import search as search_function from cognee.modules.data.methods import get_authorized_existing_datasets @@ -32,11 +32,10 @@ async def search( save_interaction: bool = False, last_k: Optional[int] = 1, only_context: bool = False, - use_combined_context: bool = False, session_id: Optional[str] = None, wide_search_top_k: Optional[int] = 100, triplet_distance_penalty: Optional[float] = 3.5, -) -> Union[List[SearchResult], CombinedSearchResult]: +) -> List[SearchResult]: """ Search and query the knowledge graph for insights, information, and connections. @@ -214,7 +213,6 @@ async def search( save_interaction=save_interaction, last_k=last_k, only_context=only_context, - use_combined_context=use_combined_context, session_id=session_id, wide_search_top_k=wide_search_top_k, triplet_distance_penalty=triplet_distance_penalty, diff --git a/cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py b/cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py index db748db64..2645c660f 100644 --- a/cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py +++ b/cognee/modules/notebooks/tutorials/python-development-with-cognee/cell-9.py @@ -27,6 +27,5 @@ await cognee.cognify(datasets=["python-development-with-cognee"], temporal_cogni results = await cognee.search( "What Python type hinting challenges did I face, and how does Guido approach similar problems in mypy?", datasets=["python-development-with-cognee"], - use_combined_context=True, # Used to show reasoning graph visualization ) print(results) diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 9f180d607..39ae70d2c 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -14,8 +14,6 @@ from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.search.types import ( SearchResult, - CombinedSearchResult, - SearchResultDataset, SearchType, ) from cognee.modules.search.operations import log_query, log_result @@ -45,11 +43,10 @@ async def search( save_interaction: bool = False, last_k: Optional[int] = None, only_context: bool = False, - use_combined_context: bool = False, session_id: Optional[str] = None, wide_search_top_k: Optional[int] = 100, triplet_distance_penalty: Optional[float] = 3.5, -) -> Union[CombinedSearchResult, List[SearchResult]]: +) -> List[SearchResult]: """ Args: @@ -90,7 +87,6 @@ async def search( save_interaction=save_interaction, last_k=last_k, only_context=only_context, - use_combined_context=use_combined_context, session_id=session_id, wide_search_top_k=wide_search_top_k, triplet_distance_penalty=triplet_distance_penalty, @@ -127,87 +123,59 @@ async def search( query.id, json.dumps( jsonable_encoder( - await prepare_search_result( - search_results[0] if isinstance(search_results, list) else search_results - ) - if use_combined_context - else [ - await prepare_search_result(search_result) for search_result in search_results - ] + [await prepare_search_result(search_result) for search_result in search_results] ) ), user.id, ) - if use_combined_context: - prepared_search_results = await prepare_search_result( - search_results[0] if isinstance(search_results, list) else search_results - ) - result = prepared_search_results["result"] - graphs = prepared_search_results["graphs"] - context = prepared_search_results["context"] - datasets = prepared_search_results["datasets"] + # This is for maintaining backwards compatibility + if backend_access_control_enabled(): + return_value = [] + for search_result in search_results: + prepared_search_results = await prepare_search_result(search_result) - return CombinedSearchResult( - result=result, - graphs=graphs, - context=context, - datasets=[ - SearchResultDataset( - id=dataset.id, - name=dataset.name, + result = prepared_search_results["result"] + graphs = prepared_search_results["graphs"] + context = prepared_search_results["context"] + datasets = prepared_search_results["datasets"] + + if only_context: + return_value.append( + { + "search_result": [context] if context else None, + "dataset_id": datasets[0].id, + "dataset_name": datasets[0].name, + "dataset_tenant_id": datasets[0].tenant_id, + "graphs": graphs, + } ) - for dataset in datasets - ], - ) + else: + return_value.append( + { + "search_result": [result] if result else None, + "dataset_id": datasets[0].id, + "dataset_name": datasets[0].name, + "dataset_tenant_id": datasets[0].tenant_id, + "graphs": graphs, + } + ) + return return_value else: - # This is for maintaining backwards compatibility - if backend_access_control_enabled(): - return_value = [] + return_value = [] + if only_context: for search_result in search_results: prepared_search_results = await prepare_search_result(search_result) - - result = prepared_search_results["result"] - graphs = prepared_search_results["graphs"] - context = prepared_search_results["context"] - datasets = prepared_search_results["datasets"] - - if only_context: - return_value.append( - { - "search_result": [context] if context else None, - "dataset_id": datasets[0].id, - "dataset_name": datasets[0].name, - "dataset_tenant_id": datasets[0].tenant_id, - "graphs": graphs, - } - ) - else: - return_value.append( - { - "search_result": [result] if result else None, - "dataset_id": datasets[0].id, - "dataset_name": datasets[0].name, - "dataset_tenant_id": datasets[0].tenant_id, - "graphs": graphs, - } - ) - return return_value + return_value.append(prepared_search_results["context"]) else: - return_value = [] - if only_context: - for search_result in search_results: - prepared_search_results = await prepare_search_result(search_result) - return_value.append(prepared_search_results["context"]) - else: - for search_result in search_results: - result, context, datasets = search_result - return_value.append(result) - # For maintaining backwards compatibility - if len(return_value) == 1 and isinstance(return_value[0], list): - return return_value[0] - else: - return return_value + for search_result in search_results: + result, context, datasets = search_result + return_value.append(result) + # For maintaining backwards compatibility + if len(return_value) == 1 and isinstance(return_value[0], list): + return return_value[0] + else: + return return_value async def authorized_search( @@ -223,14 +191,10 @@ async def authorized_search( save_interaction: bool = False, last_k: Optional[int] = None, only_context: bool = False, - use_combined_context: bool = False, session_id: Optional[str] = None, wide_search_top_k: Optional[int] = 100, triplet_distance_penalty: Optional[float] = 3.5, -) -> Union[ - Tuple[Any, Union[List[Edge], str], List[Dataset]], - List[Tuple[Any, Union[List[Edge], str], List[Dataset]]], -]: +) -> List[Tuple[Any, Union[List[Edge], str], List[Dataset]]]: """ Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset. Not to be used outside of active access control mode. @@ -240,70 +204,6 @@ async def authorized_search( datasets=dataset_ids, permission_type="read", user=user ) - if use_combined_context: - search_responses = await search_in_datasets_context( - search_datasets=search_datasets, - query_type=query_type, - query_text=query_text, - system_prompt_path=system_prompt_path, - system_prompt=system_prompt, - top_k=top_k, - node_type=node_type, - node_name=node_name, - save_interaction=save_interaction, - last_k=last_k, - only_context=True, - session_id=session_id, - wide_search_top_k=wide_search_top_k, - triplet_distance_penalty=triplet_distance_penalty, - ) - - context = {} - datasets: List[Dataset] = [] - - for _, search_context, search_datasets in search_responses: - for dataset in search_datasets: - context[str(dataset.id)] = search_context - - datasets.extend(search_datasets) - - specific_search_tools = await get_search_type_tools( - query_type=query_type, - query_text=query_text, - system_prompt_path=system_prompt_path, - system_prompt=system_prompt, - top_k=top_k, - node_type=node_type, - node_name=node_name, - save_interaction=save_interaction, - last_k=last_k, - wide_search_top_k=wide_search_top_k, - triplet_distance_penalty=triplet_distance_penalty, - ) - search_tools = specific_search_tools - if len(search_tools) == 2: - [get_completion, _] = search_tools - else: - get_completion = search_tools[0] - - def prepare_combined_context( - context, - ) -> Union[List[Edge], str]: - combined_context = [] - - for dataset_context in context.values(): - combined_context += dataset_context - - if combined_context and isinstance(combined_context[0], str): - return "\n".join(combined_context) - - return combined_context - - combined_context = prepare_combined_context(context) - completion = await get_completion(query_text, combined_context, session_id=session_id) - - return completion, combined_context, datasets - # Searches all provided datasets and handles setting up of appropriate database context based on permissions search_results = await search_in_datasets_context( search_datasets=search_datasets, @@ -319,6 +219,7 @@ async def authorized_search( only_context=only_context, session_id=session_id, wide_search_top_k=wide_search_top_k, + triplet_distance_penalty=triplet_distance_penalty, ) return search_results diff --git a/cognee/modules/search/types/SearchResult.py b/cognee/modules/search/types/SearchResult.py index 8ea5d3990..828dde725 100644 --- a/cognee/modules/search/types/SearchResult.py +++ b/cognee/modules/search/types/SearchResult.py @@ -1,6 +1,6 @@ from uuid import UUID from pydantic import BaseModel -from typing import Any, Dict, List, Optional +from typing import Any, Optional class SearchResultDataset(BaseModel): @@ -8,13 +8,6 @@ class SearchResultDataset(BaseModel): name: str -class CombinedSearchResult(BaseModel): - result: Optional[Any] - context: Dict[str, Any] - graphs: Optional[Dict[str, Any]] = {} - datasets: Optional[List[SearchResultDataset]] = None - - class SearchResult(BaseModel): search_result: Any dataset_id: Optional[UUID] diff --git a/cognee/modules/search/types/__init__.py b/cognee/modules/search/types/__init__.py index 06e267f95..2e6466703 100644 --- a/cognee/modules/search/types/__init__.py +++ b/cognee/modules/search/types/__init__.py @@ -1,2 +1,2 @@ from .SearchType import SearchType -from .SearchResult import SearchResult, SearchResultDataset, CombinedSearchResult +from .SearchResult import SearchResult, SearchResultDataset