feat: Enable nodesets on backend (#1314)


## Description

This PR enables node sets on the backend:

- The add endpoint (`POST /v1/add`) accepts a new `node_set` form field, so ingested data can be tagged with node identifiers used for graph organization and access control.
- The search endpoint and `SearchPayloadDTO` gain a `node_name` field that filters results to specific node sets defined during the add pipeline.
- `node_type` now defaults to `NodeSet` across the search path, and `node_type`/`node_name` are threaded through `search`, `specific_search`, `authorized_search`, and `specific_search_by_context`.
- Graph projection falls back to the full graph when `node_name` is `None` or an empty list.
- The add endpoint's legacy GitHub-clone and URL-fetch branches are removed in favor of a single `cognee_add` call.
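
For reviewers, a minimal sketch of how a client might exercise both new fields. The `node_set` and `node_name` field names come from the diffs below; the base URL, the `/v1/search` path, and the omission of authentication are assumptions, not part of this PR:

```python
import requests

BASE = "http://localhost:8000/api"  # hypothetical deployment; not from this PR

# Tag data with node sets at ingestion time via the new `node_set` form field.
with open("report.pdf", "rb") as f:
    response = requests.post(
        f"{BASE}/v1/add",
        files={"data": f},
        data={"datasetName": "finance", "node_set": ["q3", "reports"]},
    )
response.raise_for_status()

# Restrict a later search to those node sets via the new `node_name` field.
response = requests.post(
    f"{BASE}/v1/search",  # route path assumed; see the search router diff below
    json={"query": "What is in the document?", "datasets": ["finance"], "node_name": ["q3"]},
)
print(response.json())
```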

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
Commit eb5631370e, authored by Vasilije on 2025-08-29 16:21:42 +02:00, committed via GitHub.
6 changed files with 51 additions and 39 deletions

File 1 of 6

```diff
@@ -25,6 +25,7 @@ def get_add_router() -> APIRouter:
         data: List[UploadFile] = File(default=None),
         datasetName: Optional[str] = Form(default=None),
         datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]),
+        node_set: Optional[List[str]] = Form(default=[""], example=[""]),
         user: User = Depends(get_authenticated_user),
     ):
         """
@@ -41,6 +42,8 @@ def get_add_router() -> APIRouter:
        - Regular file uploads
        - **datasetName** (Optional[str]): Name of the dataset to add data to
        - **datasetId** (Optional[UUID]): UUID of an already existing dataset
+       - **node_set** (Optional[List[str]]): List of node identifiers for graph organization and access control.
+         Used for grouping related data points in the knowledge graph.

        Either datasetName or datasetId must be provided.
```
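
For context on how the new parameter binds: FastAPI collects repeated multipart/form fields into a `List[str] = Form(...)` parameter. A self-contained sketch with an illustrative app, not code from this repo:

```python
from typing import List, Optional

from fastapi import FastAPI, Form
from fastapi.testclient import TestClient

app = FastAPI()

@app.post("/echo")
async def echo(node_set: Optional[List[str]] = Form(default=None)):
    # Repeated form fields (node_set=a&node_set=b) arrive here as ["a", "b"].
    return {"node_set": node_set}

client = TestClient(app)
print(client.post("/echo", data={"node_set": ["q3", "reports"]}).json())
# -> {'node_set': ['q3', 'reports']}
```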
```diff
@@ -65,9 +68,7 @@ def get_add_router() -> APIRouter:
         send_telemetry(
             "Add API Endpoint Invoked",
             user.id,
-            additional_properties={
-                "endpoint": "POST /v1/add",
-            },
+            additional_properties={"endpoint": "POST /v1/add", "node_set": node_set},
         )

         from cognee.api.v1.add import add as cognee_add
@@ -76,34 +77,13 @@ def get_add_router() -> APIRouter:
             raise ValueError("Either datasetId or datasetName must be provided.")

         try:
-            if (
-                isinstance(data, str)
-                and data.startswith("http")
-                and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true")
-            ):
-                if "github" in data:
-                    # Perform git clone if the URL is from GitHub
-                    repo_name = data.split("/")[-1].replace(".git", "")
-                    subprocess.run(["git", "clone", data, f".data/{repo_name}"], check=True)
-                    # TODO: Update add call with dataset info
-                    await cognee_add(
-                        "data://.data/",
-                        f"{repo_name}",
-                    )
-                else:
-                    # Fetch and store the data from other types of URL using curl
-                    response = requests.get(data)
-                    response.raise_for_status()
-                    file_data = await response.content()
-                    # TODO: Update add call with dataset info
-                    return await cognee_add(file_data)
-            else:
-                add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
-                if isinstance(add_run, PipelineRunErrored):
-                    return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
-                return add_run.model_dump()
+            add_run = await cognee_add(
+                data, datasetName, user=user, dataset_id=datasetId, node_set=node_set
+            )
+            if isinstance(add_run, PipelineRunErrored):
+                return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
+            return add_run.model_dump()
         except Exception as error:
             return JSONResponse(status_code=409, content={"error": str(error)})
```
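
With the legacy GitHub-clone and URL-fetch branches removed, the handler body reduces to one `cognee_add` call. A hedged library-level equivalent; the `get_default_user` import path appears elsewhere in this PR, and the sample text and dataset name are illustrative:

```python
import asyncio

from cognee.api.v1.add import add as cognee_add
from cognee.modules.users.methods import get_default_user  # import path shown in this PR

async def main():
    user = await get_default_user()
    add_run = await cognee_add(
        "Q3 revenue grew 12% year over year.",  # illustrative raw-text payload
        "finance",                              # illustrative dataset name
        user=user,
        node_set=["q3", "reports"],             # tags the ingested data points
    )
    print(add_run.model_dump())

asyncio.run(main())
```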

File 2 of 6

```diff
@@ -20,6 +20,7 @@ class SearchPayloadDTO(InDTO):
     datasets: Optional[list[str]] = Field(default=None)
     dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
     query: str = Field(default="What is in the document?")
+    node_name: Optional[list[str]] = Field(default=None, example=[])
     top_k: Optional[int] = Field(default=10)
@@ -79,6 +80,7 @@ def get_search_router() -> APIRouter:
    - **datasets** (Optional[List[str]]): List of dataset names to search within
    - **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within
    - **query** (str): The search query string
+   - **node_name** (Optional[list[str]]): Filter results to specific node_sets defined in the add pipeline (for targeted search).
    - **top_k** (Optional[int]): Maximum number of results to return (default: 10)

    ## Response
@@ -102,6 +104,7 @@ def get_search_router() -> APIRouter:
             "datasets": payload.datasets,
             "dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
             "query": payload.query,
+            "node_name": payload.node_name,
             "top_k": payload.top_k,
         },
     )
@@ -115,6 +118,7 @@ def get_search_router() -> APIRouter:
             user=user,
             datasets=payload.datasets,
             dataset_ids=payload.dataset_ids,
+            node_name=payload.node_name,
             top_k=payload.top_k,
         )
```
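
An example request body the updated `SearchPayloadDTO` would accept; field names come from the diff above, values are illustrative:

```python
payload = {
    "query": "What is in the document?",
    "datasets": ["finance"],  # optional dataset-name scoping
    "node_name": ["q3"],      # new: restrict results to these node sets
    "top_k": 5,
}
```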

File 3 of 6

```diff
@@ -1,6 +1,7 @@
 from uuid import UUID
 from typing import Union, Optional, List, Type
+from cognee.modules.engine.models.node_set import NodeSet
 from cognee.modules.users.models import User
 from cognee.modules.search.types import SearchType
 from cognee.modules.users.methods import get_default_user
@@ -17,7 +18,7 @@ async def search(
     dataset_ids: Optional[Union[list[UUID], UUID]] = None,
     system_prompt_path: str = "answer_simple_question.txt",
     top_k: int = 10,
-    node_type: Optional[Type] = None,
+    node_type: Optional[Type] = NodeSet,
     node_name: Optional[List[str]] = None,
     save_interaction: bool = False,
     last_k: Optional[int] = None,
```
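
Because `node_type` now defaults to `NodeSet`, scoping a search to a node set only requires passing `node_name`. A sketch of the top-level call; the leading keyword names and the `SearchType` member are assumptions, since they are not shown in this hunk:

```python
import asyncio

import cognee
from cognee.modules.search.types import SearchType

async def main():
    # With node_type defaulting to NodeSet, node_name alone scopes the search.
    results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,  # member name is an assumption
        query_text="What happened in Q3?",       # leading kwargs are assumptions
        node_name=["q3"],
    )
    print(results)

asyncio.run(main())
```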

File 4 of 6

```diff
@@ -76,7 +76,7 @@ class CogneeGraph(CogneeAbstractGraph):
         start_time = time.time()

         # Determine projection strategy
-        if node_type is not None and node_name is not None:
+        if node_type is not None and node_name not in [None, []]:
             nodes_data, edges_data = await adapter.get_nodeset_subgraph(
                 node_type=node_type, node_name=node_name
             )
```
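
The tightened guard treats an empty `node_name` list like `None`: only a non-empty list selects the node-set subgraph, and everything else falls through to full-graph projection. This matters now that `node_type` defaults to `NodeSet`. A standalone restatement of the predicate:

```python
from typing import List, Optional, Type

def use_nodeset_projection(node_type: Optional[Type], node_name: Optional[List[str]]) -> bool:
    # Mirrors the updated guard: only a non-empty node_name list selects
    # the node-set subgraph; None and [] fall back to full-graph projection.
    return node_type is not None and node_name not in [None, []]

assert use_nodeset_projection(object, ["q3"]) is True
assert use_nodeset_projection(object, []) is False    # old guard would have said True
assert use_nodeset_projection(object, None) is False
assert use_nodeset_projection(None, ["q3"]) is False
```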

File 5 of 6

```diff
@@ -4,6 +4,7 @@ import asyncio
 from uuid import UUID
 from typing import Callable, List, Optional, Type, Union
+from cognee.modules.engine.models.node_set import NodeSet
 from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
 from cognee.modules.search.exceptions import UnsupportedSearchTypeError
 from cognee.context_global_variables import set_database_global_context_variables
@@ -38,7 +39,7 @@ async def search(
     user: User,
     system_prompt_path="answer_simple_question.txt",
     top_k: int = 10,
-    node_type: Optional[Type] = None,
+    node_type: Optional[Type] = NodeSet,
     node_name: Optional[List[str]] = None,
     save_interaction: Optional[bool] = False,
     last_k: Optional[int] = None,
@@ -67,6 +68,8 @@ async def search(
         dataset_ids=dataset_ids,
         system_prompt_path=system_prompt_path,
         top_k=top_k,
+        node_type=node_type,
+        node_name=node_name,
         save_interaction=save_interaction,
         last_k=last_k,
     )
@@ -102,7 +105,7 @@ async def specific_search(
     user: User,
     system_prompt_path="answer_simple_question.txt",
     top_k: int = 10,
-    node_type: Optional[Type] = None,
+    node_type: Optional[Type] = NodeSet,
     node_name: Optional[List[str]] = None,
     save_interaction: Optional[bool] = False,
     last_k: Optional[int] = None,
@@ -173,6 +176,8 @@ async def authorized_search(
     dataset_ids: Optional[list[UUID]] = None,
     system_prompt_path: str = "answer_simple_question.txt",
     top_k: int = 10,
+    node_type: Optional[Type] = NodeSet,
+    node_name: Optional[List[str]] = None,
     save_interaction: bool = False,
     last_k: Optional[int] = None,
 ) -> list:
@@ -194,7 +199,9 @@
         user,
         system_prompt_path,
         top_k,
-        save_interaction,
+        node_type=node_type,
+        node_name=node_name,
+        save_interaction=save_interaction,
         last_k=last_k,
     )
```
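
Note that this call site also switches from positional to keyword arguments, which matters here: with the old positional form, the value meant for `save_interaction` would bind to the newly inserted `node_type` slot. A toy illustration with a hypothetical function, not the real `specific_search` signature:

```python
# Toy illustration of why the call site moved to keyword arguments.
def specific(query, top_k=10, node_type=None, node_name=None, save_interaction=False):
    return node_type, save_interaction

# Old positional style: after node_type/node_name were inserted before
# save_interaction, the boolean meant for save_interaction binds to node_type.
print(specific("q", 10, False))                   # (False, False) -- misbound
print(specific("q", 10, save_interaction=False))  # (None, False)  -- intended
```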
```diff
@@ -210,6 +217,8 @@ async def specific_search_by_context(
     user: User,
     system_prompt_path: str,
     top_k: int,
+    node_type: Optional[Type] = NodeSet,
+    node_name: Optional[List[str]] = None,
     save_interaction: bool = False,
     last_k: Optional[int] = None,
 ):
@@ -219,7 +228,15 @@
     """

     async def _search_by_context(
-        dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
+        dataset,
+        user,
+        query_type,
+        query_text,
+        system_prompt_path,
+        top_k,
+        node_type: Optional[Type] = NodeSet,
+        node_name: Optional[List[str]] = None,
+        last_k: Optional[int] = None,
     ):
         # Set database configuration in async context for each dataset user has access for
         await set_database_global_context_variables(dataset.id, dataset.owner_id)
@@ -229,6 +246,8 @@
             user,
             system_prompt_path=system_prompt_path,
             top_k=top_k,
+            node_type=node_type,
+            node_name=node_name,
             save_interaction=save_interaction,
             last_k=last_k,
         )
@@ -243,7 +262,15 @@
     for dataset in search_datasets:
         tasks.append(
             _search_by_context(
-                dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
+                dataset,
+                user,
+                query_type,
+                query_text,
+                system_prompt_path,
+                top_k,
+                node_type=node_type,
+                node_name=node_name,
+                last_k=last_k,
             )
         )
```
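
Each dataset gets its own `_search_by_context` task so the database context can be set per dataset before searching; awaiting the collected tasks together is the usual pattern and an assumption here, since the gather itself is outside this hunk. A minimal sketch of that fan-out:

```python
import asyncio

async def search_one_dataset(dataset_id: str, query: str) -> str:
    # Stand-in for _search_by_context: set per-dataset DB context, then search.
    await asyncio.sleep(0)  # placeholder for set_database_global_context_variables(...)
    return f"{dataset_id}: results for {query!r}"

async def main():
    datasets = ["finance", "legal"]  # illustrative dataset ids
    tasks = [search_one_dataset(d, "q3 revenue") for d in datasets]
    print(await asyncio.gather(*tasks))  # one task per dataset, run concurrently

asyncio.run(main())
```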

File 6 of 6

```diff
@@ -3,8 +3,8 @@ import uuid
 from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
-from pylint.checkers.utils import node_type
+from cognee.modules.engine.models.node_set import NodeSet

 from cognee.modules.search.exceptions import UnsupportedSearchTypeError
 from cognee.modules.search.methods.search import search, specific_search
 from cognee.modules.search.types import SearchType
@@ -63,7 +63,7 @@ async def test_search(
         mock_user,
         system_prompt_path="answer_simple_question.txt",
         top_k=10,
-        node_type=None,
+        node_type=NodeSet,
         node_name=None,
         save_interaction=False,
         last_k=None,
```