feat: Add nodeset as default node type
This commit is contained in:
parent
cf636ba77f
commit
21f688385b
3 changed files with 27 additions and 7 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Union, Optional, List, Type
|
from typing import Union, Optional, List, Type
|
||||||
|
|
||||||
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.models import User
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
from cognee.modules.users.methods import get_default_user
|
from cognee.modules.users.methods import get_default_user
|
||||||
|
|
@ -17,7 +18,7 @@ async def search(
|
||||||
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
|
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = NodeSet,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import Callable, List, Optional, Type, Union
|
from typing import Callable, List, Optional, Type, Union
|
||||||
|
|
||||||
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
|
||||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||||
from cognee.context_global_variables import set_database_global_context_variables
|
from cognee.context_global_variables import set_database_global_context_variables
|
||||||
|
|
@ -38,7 +39,7 @@ async def search(
|
||||||
user: User,
|
user: User,
|
||||||
system_prompt_path="answer_simple_question.txt",
|
system_prompt_path="answer_simple_question.txt",
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = NodeSet,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: Optional[bool] = False,
|
save_interaction: Optional[bool] = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
|
|
@ -67,6 +68,8 @@ async def search(
|
||||||
dataset_ids=dataset_ids,
|
dataset_ids=dataset_ids,
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
)
|
)
|
||||||
|
|
@ -102,7 +105,7 @@ async def specific_search(
|
||||||
user: User,
|
user: User,
|
||||||
system_prompt_path="answer_simple_question.txt",
|
system_prompt_path="answer_simple_question.txt",
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
node_type: Optional[Type] = None,
|
node_type: Optional[Type] = NodeSet,
|
||||||
node_name: Optional[List[str]] = None,
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: Optional[bool] = False,
|
save_interaction: Optional[bool] = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
|
|
@ -173,6 +176,8 @@ async def authorized_search(
|
||||||
dataset_ids: Optional[list[UUID]] = None,
|
dataset_ids: Optional[list[UUID]] = None,
|
||||||
system_prompt_path: str = "answer_simple_question.txt",
|
system_prompt_path: str = "answer_simple_question.txt",
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
|
node_type: Optional[Type] = NodeSet,
|
||||||
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
|
|
@ -194,7 +199,9 @@ async def authorized_search(
|
||||||
user,
|
user,
|
||||||
system_prompt_path,
|
system_prompt_path,
|
||||||
top_k,
|
top_k,
|
||||||
save_interaction,
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -210,6 +217,8 @@ async def specific_search_by_context(
|
||||||
user: User,
|
user: User,
|
||||||
system_prompt_path: str,
|
system_prompt_path: str,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
|
node_type: Optional[Type] = NodeSet,
|
||||||
|
node_name: Optional[List[str]] = None,
|
||||||
save_interaction: bool = False,
|
save_interaction: bool = False,
|
||||||
last_k: Optional[int] = None,
|
last_k: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
|
@ -229,6 +238,8 @@ async def specific_search_by_context(
|
||||||
user,
|
user,
|
||||||
system_prompt_path=system_prompt_path,
|
system_prompt_path=system_prompt_path,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
save_interaction=save_interaction,
|
save_interaction=save_interaction,
|
||||||
last_k=last_k,
|
last_k=last_k,
|
||||||
)
|
)
|
||||||
|
|
@ -243,7 +254,15 @@ async def specific_search_by_context(
|
||||||
for dataset in search_datasets:
|
for dataset in search_datasets:
|
||||||
tasks.append(
|
tasks.append(
|
||||||
_search_by_context(
|
_search_by_context(
|
||||||
dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
|
dataset,
|
||||||
|
user,
|
||||||
|
query_type,
|
||||||
|
query_text,
|
||||||
|
system_prompt_path,
|
||||||
|
top_k,
|
||||||
|
node_type=node_type,
|
||||||
|
node_name=node_name,
|
||||||
|
last_k=last_k,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,8 @@ import uuid
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pylint.checkers.utils import node_type
|
|
||||||
|
|
||||||
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
from cognee.modules.search.exceptions import UnsupportedSearchTypeError
|
||||||
from cognee.modules.search.methods.search import search, specific_search
|
from cognee.modules.search.methods.search import search, specific_search
|
||||||
from cognee.modules.search.types import SearchType
|
from cognee.modules.search.types import SearchType
|
||||||
|
|
@ -63,7 +63,7 @@ async def test_search(
|
||||||
mock_user,
|
mock_user,
|
||||||
system_prompt_path="answer_simple_question.txt",
|
system_prompt_path="answer_simple_question.txt",
|
||||||
top_k=10,
|
top_k=10,
|
||||||
node_type=None,
|
node_type=NodeSet,
|
||||||
node_name=None,
|
node_name=None,
|
||||||
save_interaction=False,
|
save_interaction=False,
|
||||||
last_k=None,
|
last_k=None,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue