diff --git a/cognee/api/v1/add/routers/get_add_router.py b/cognee/api/v1/add/routers/get_add_router.py index 66b165a38..1703d9931 100644 --- a/cognee/api/v1/add/routers/get_add_router.py +++ b/cognee/api/v1/add/routers/get_add_router.py @@ -25,6 +25,7 @@ def get_add_router() -> APIRouter: data: List[UploadFile] = File(default=None), datasetName: Optional[str] = Form(default=None), datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]), + node_set: Optional[List[str]] = Form(default=[""], example=[""]), user: User = Depends(get_authenticated_user), ): """ @@ -41,6 +42,8 @@ def get_add_router() -> APIRouter: - Regular file uploads - **datasetName** (Optional[str]): Name of the dataset to add data to - **datasetId** (Optional[UUID]): UUID of an already existing dataset + - **node_set** Optional[list[str]]: List of node identifiers for graph organization and access control. + Used for grouping related data points in the knowledge graph. Either datasetName or datasetId must be provided. @@ -65,9 +68,7 @@ def get_add_router() -> APIRouter: send_telemetry( "Add API Endpoint Invoked", user.id, - additional_properties={ - "endpoint": "POST /v1/add", - }, + additional_properties={"endpoint": "POST /v1/add", "node_set": node_set}, ) from cognee.api.v1.add import add as cognee_add @@ -76,34 +77,13 @@ def get_add_router() -> APIRouter: raise ValueError("Either datasetId or datasetName must be provided.") try: - if ( - isinstance(data, str) - and data.startswith("http") - and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true") - ): - if "github" in data: - # Perform git clone if the URL is from GitHub - repo_name = data.split("/")[-1].replace(".git", "") - subprocess.run(["git", "clone", data, f".data/{repo_name}"], check=True) - # TODO: Update add call with dataset info - await cognee_add( - "data://.data/", - f"{repo_name}", - ) - else: - # Fetch and store the data from other types of URL using curl - response = requests.get(data) - response.raise_for_status() + add_run = await cognee_add( + data, datasetName, user=user, dataset_id=datasetId, node_set=node_set + ) - file_data = await response.content() - # TODO: Update add call with dataset info - return await cognee_add(file_data) - else: - add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId) - - if isinstance(add_run, PipelineRunErrored): - return JSONResponse(status_code=420, content=add_run.model_dump(mode="json")) - return add_run.model_dump() + if isinstance(add_run, PipelineRunErrored): + return JSONResponse(status_code=420, content=add_run.model_dump(mode="json")) + return add_run.model_dump() except Exception as error: return JSONResponse(status_code=409, content={"error": str(error)}) diff --git a/cognee/api/v1/search/routers/get_search_router.py b/cognee/api/v1/search/routers/get_search_router.py index 0ceeb1abb..003df7cd4 100644 --- a/cognee/api/v1/search/routers/get_search_router.py +++ b/cognee/api/v1/search/routers/get_search_router.py @@ -20,6 +20,7 @@ class SearchPayloadDTO(InDTO): datasets: Optional[list[str]] = Field(default=None) dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]]) query: str = Field(default="What is in the document?") + node_name: Optional[list[str]] = Field(default=None, example=[]) top_k: Optional[int] = Field(default=10) @@ -79,6 +80,7 @@ def get_search_router() -> APIRouter: - **datasets** (Optional[List[str]]): List of dataset names to search within - **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within - **query** (str): The search query string + - **node_name** Optional[list[str]]: Filter results to specific node_sets defined in the add pipeline (for targeted search). - **top_k** (Optional[int]): Maximum number of results to return (default: 10) ## Response @@ -102,6 +104,7 @@ def get_search_router() -> APIRouter: "datasets": payload.datasets, "dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []], "query": payload.query, + "node_name": payload.node_name, "top_k": payload.top_k, }, ) @@ -115,6 +118,7 @@ def get_search_router() -> APIRouter: user=user, datasets=payload.datasets, dataset_ids=payload.dataset_ids, + node_name=payload.node_name, top_k=payload.top_k, ) diff --git a/cognee/api/v1/search/search.py b/cognee/api/v1/search/search.py index f37f8ba6d..344e763ae 100644 --- a/cognee/api/v1/search/search.py +++ b/cognee/api/v1/search/search.py @@ -1,6 +1,7 @@ from uuid import UUID from typing import Union, Optional, List, Type +from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.users.models import User from cognee.modules.search.types import SearchType from cognee.modules.users.methods import get_default_user @@ -17,7 +18,7 @@ async def search( dataset_ids: Optional[Union[list[UUID], UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", top_k: int = 10, - node_type: Optional[Type] = None, + node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py index ed867ae24..924532ce0 100644 --- a/cognee/modules/graph/cognee_graph/CogneeGraph.py +++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py @@ -76,7 +76,7 @@ class CogneeGraph(CogneeAbstractGraph): start_time = time.time() # Determine projection strategy - if node_type is not None and node_name is not None: + if node_type is not None and node_name not in [None, []]: nodes_data, edges_data = await adapter.get_nodeset_subgraph( node_type=node_type, node_name=node_name ) diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index f5f2a793a..74ef2a6ad 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -4,6 +4,7 @@ import asyncio from uuid import UUID from typing import Callable, List, Optional, Type, Union +from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.context_global_variables import set_database_global_context_variables @@ -38,7 +39,7 @@ async def search( user: User, system_prompt_path="answer_simple_question.txt", top_k: int = 10, - node_type: Optional[Type] = None, + node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, @@ -67,6 +68,8 @@ async def search( dataset_ids=dataset_ids, system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, save_interaction=save_interaction, last_k=last_k, ) @@ -102,7 +105,7 @@ async def specific_search( user: User, system_prompt_path="answer_simple_question.txt", top_k: int = 10, - node_type: Optional[Type] = None, + node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None, save_interaction: Optional[bool] = False, last_k: Optional[int] = None, @@ -173,6 +176,8 @@ async def authorized_search( dataset_ids: Optional[list[UUID]] = None, system_prompt_path: str = "answer_simple_question.txt", top_k: int = 10, + node_type: Optional[Type] = NodeSet, + node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, ) -> list: @@ -194,7 +199,9 @@ async def authorized_search( user, system_prompt_path, top_k, - save_interaction, + node_type=node_type, + node_name=node_name, + save_interaction=save_interaction, last_k=last_k, ) @@ -210,6 +217,8 @@ async def specific_search_by_context( user: User, system_prompt_path: str, top_k: int, + node_type: Optional[Type] = NodeSet, + node_name: Optional[List[str]] = None, save_interaction: bool = False, last_k: Optional[int] = None, ): @@ -219,7 +228,15 @@ async def specific_search_by_context( """ async def _search_by_context( - dataset, user, query_type, query_text, system_prompt_path, top_k, last_k + dataset, + user, + query_type, + query_text, + system_prompt_path, + top_k, + node_type: Optional[Type] = NodeSet, + node_name: Optional[List[str]] = None, + last_k: Optional[int] = None, ): # Set database configuration in async context for each dataset user has access for await set_database_global_context_variables(dataset.id, dataset.owner_id) @@ -229,6 +246,8 @@ async def specific_search_by_context( user, system_prompt_path=system_prompt_path, top_k=top_k, + node_type=node_type, + node_name=node_name, save_interaction=save_interaction, last_k=last_k, ) @@ -243,7 +262,15 @@ async def specific_search_by_context( for dataset in search_datasets: tasks.append( _search_by_context( - dataset, user, query_type, query_text, system_prompt_path, top_k, last_k + dataset, + user, + query_type, + query_text, + system_prompt_path, + top_k, + node_type=node_type, + node_name=node_name, + last_k=last_k, ) ) diff --git a/cognee/tests/unit/modules/search/search_methods_test.py b/cognee/tests/unit/modules/search/search_methods_test.py index 46995d087..004e1fca3 100644 --- a/cognee/tests/unit/modules/search/search_methods_test.py +++ b/cognee/tests/unit/modules/search/search_methods_test.py @@ -3,8 +3,8 @@ import uuid from unittest.mock import AsyncMock, MagicMock, patch import pytest -from pylint.checkers.utils import node_type +from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.search.exceptions import UnsupportedSearchTypeError from cognee.modules.search.methods.search import search, specific_search from cognee.modules.search.types import SearchType @@ -63,7 +63,7 @@ async def test_search( mock_user, system_prompt_path="answer_simple_question.txt", top_k=10, - node_type=None, + node_type=NodeSet, node_name=None, save_interaction=False, last_k=None,