From 21f688385b16cc3bc50d355b32eb4b7610df2053 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Fri, 29 Aug 2025 12:53:29 +0200 Subject: [PATCH] feat: Add nodeset as default node type --- cognee/api/v1/search/search.py | 3 ++- cognee/modules/search/methods/search.py | 27 ++++++++++++++++--- .../modules/search/search_methods_test.py | 4 +-- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index f37f8ba6d..344e763ae 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -1,6 +1,7 @@ from uuid import UUID from typing import Union, Optional, List, Type +from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.users.models import User from cognee.modules.search.types import SearchType from cognee.modules.users.methods import get_default_user @@ -17,7 +18,7 @@ async def search( dataset_ids: Optional[Union[list[UUID], UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", top_k: int = 10, - node_type: Optional[Type] = None, + node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index f5f2a793a..8e38e63c3 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -4,6 +4,7 @@ import asyncio from uuid import UUID from typing import Callable, List, Optional, Type, Union +from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.context_global_variables import set_database_global_context_variables @@ -38,7 +39,7 @@ async def search( user: User, system_prompt_path="answer_simple_question.txt", top_k: int = 10, - node_type: Optional[Type] = None, + node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, @@ -67,6 +68,8 @@ async def search( dataset_ids=dataset_ids, system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, save_interaction=save_interaction, last_k=last_k, ) @@ -102,7 +105,7 @@ async def specific_search( user: User, system_prompt_path="answer_simple_question.txt", top_k: int = 10, - node_type: Optional[Type] = None, + node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, @@ -173,6 +176,8 @@ async def authorized_search( dataset_ids: Optional[list[UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", top_k: int = 10, + node_type: Optional[Type] = NodeSet, + node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, ) -> list: @@ -194,7 +199,9 @@ async def authorized_search( user, system_prompt_path, top_k, - save_interaction, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, last_k=last_k, ) @@ -210,6 +217,8 @@ async def specific_search_by_context( user: User, system_prompt_path: str, top_k: int, + node_type: Optional[Type] = NodeSet, + node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, ): @@ -229,6 +238,8 @@ async def specific_search_by_context( user, system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, save_interaction=save_interaction, last_k=last_k, ) @@ -243,7 +254,15 @@ async def specific_search_by_context( for dataset in search_datasets: tasks.append( _search_by_context( - dataset, user, query_type, query_text, system_prompt_path, top_k, last_k + dataset, + user, + query_type, + query_text, + system_prompt_path, + top_k, + node_type=node_type, + node_name=node_name, + last_k=last_k, ) ) diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py index 46995d087..004e1fca3 100644 --- a/cognee/tests/unit/modules/search/search_methods_test.py +++ b/cognee/tests/unit/modules/search/search_methods_test.py @@ -3,8 +3,8 @@ import uuid from unittest.mock import AsyncMock, MagicMock, patch import pytest -from pylint.checkers.utils import node_type +from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.modules.search.methods.search import search, specific_search from cognee.modules.search.types import SearchType @@ -63,7 +63,7 @@ async def test_search( mock_user, system_prompt_path="answer_simple_question.txt", top_k=10, - node_type=None, + node_type=NodeSet, node_name=None, save_interaction=False, last_k=None,