feat: Add nodeset as default node type
This commit is contained in:
parent
cf636ba77f
commit
21f688385b
3 changed files with 27 additions and 7 deletions
|
|
@ -1,6 +1,7 @@
|
|||
from uuid import UUID
|
||||
from typing import Union, Optional, List, Type
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.search.types import SearchType
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
|
|
@ -17,7 +18,7 @@ async def search(
|
|||
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import asyncio
|
|||
from uuid import UUID
|
||||
from typing import Callable, List, Optional, Type, Union
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
|
|
@ -38,7 +39,7 @@ async def search(
|
|||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
last_k: Optional[int] = None,
|
||||
|
|
@ -67,6 +68,8 @@ async def search(
|
|||
dataset_ids=dataset_ids,
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
)
|
||||
|
|
@ -102,7 +105,7 @@ async def specific_search(
|
|||
user: User,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: Optional[bool] = False,
|
||||
last_k: Optional[int] = None,
|
||||
|
|
@ -173,6 +176,8 @@ async def authorized_search(
|
|||
dataset_ids: Optional[list[UUID]] = None,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
) -> list:
|
||||
|
|
@ -194,7 +199,9 @@ async def authorized_search(
|
|||
user,
|
||||
system_prompt_path,
|
||||
top_k,
|
||||
save_interaction,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
)
|
||||
|
||||
|
|
@ -210,6 +217,8 @@ async def specific_search_by_context(
|
|||
user: User,
|
||||
system_prompt_path: str,
|
||||
top_k: int,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
save_interaction: bool = False,
|
||||
last_k: Optional[int] = None,
|
||||
):
|
||||
|
|
@ -229,6 +238,8 @@ async def specific_search_by_context(
|
|||
user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
save_interaction=save_interaction,
|
||||
last_k=last_k,
|
||||
)
|
||||
|
|
@ -243,7 +254,15 @@ async def specific_search_by_context(
|
|||
for dataset in search_datasets:
|
||||
tasks.append(
|
||||
_search_by_context(
|
||||
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
|
||||
dataset,
|
||||
user,
|
||||
query_type,
|
||||
query_text,
|
||||
system_prompt_path,
|
||||
top_k,
|
||||
node_type=node_type,
|
||||
node_name=node_name,
|
||||
last_k=last_k,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ import uuid
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pylint.checkers.utils import node_type
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||
from cognee.modules.search.methods.search import search, specific_search
|
||||
from cognee.modules.search.types import SearchType
|
||||
|
|
@ -63,7 +63,7 @@ async def test_search(
|
|||
mock_user,
|
||||
system_prompt_path="answer_simple_question.txt",
|
||||
top_k=10,
|
||||
node_type=None,
|
||||
node_type=NodeSet,
|
||||
node_name=None,
|
||||
save_interaction=False,
|
||||
last_k=None,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue