feat: Enable nodesets on backend (#1314)
## Description

Enable node sets on the backend: `POST /v1/add` now accepts a `node_set` form field that tags ingested data with node-set labels, and `POST /v1/search` accepts a `node_name` filter (with `node_type` defaulting to `NodeSet`) so searches can be scoped to those labels.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
Commit: eb5631370e

6 changed files with 51 additions and 39 deletions
@@ -25,6 +25,7 @@ def get_add_router() -> APIRouter:
         data: List[UploadFile] = File(default=None),
         datasetName: Optional[str] = Form(default=None),
         datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]),
+        node_set: Optional[List[str]] = Form(default=[""], example=[""]),
         user: User = Depends(get_authenticated_user),
     ):
         """
@@ -41,6 +42,8 @@ def get_add_router() -> APIRouter:
         - Regular file uploads
         - **datasetName** (Optional[str]): Name of the dataset to add data to
         - **datasetId** (Optional[UUID]): UUID of an already existing dataset
+        - **node_set** Optional[list[str]]: List of node identifiers for graph organization and access control.
+          Used for grouping related data points in the knowledge graph.

         Either datasetName or datasetId must be provided.

@@ -65,9 +68,7 @@ def get_add_router() -> APIRouter:
         send_telemetry(
             "Add API Endpoint Invoked",
             user.id,
-            additional_properties={
-                "endpoint": "POST /v1/add",
-            },
+            additional_properties={"endpoint": "POST /v1/add", "node_set": node_set},
         )

         from cognee.api.v1.add import add as cognee_add
@@ -76,34 +77,13 @@ def get_add_router() -> APIRouter:
             raise ValueError("Either datasetId or datasetName must be provided.")

         try:
-            if (
-                isinstance(data, str)
-                and data.startswith("http")
-                and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true")
-            ):
-                if "github" in data:
-                    # Perform git clone if the URL is from GitHub
-                    repo_name = data.split("/")[-1].replace(".git", "")
-                    subprocess.run(["git", "clone", data, f".data/{repo_name}"], check=True)
-                    # TODO: Update add call with dataset info
-                    await cognee_add(
-                        "data://.data/",
-                        f"{repo_name}",
-                    )
-                else:
-                    # Fetch and store the data from other types of URL using curl
-                    response = requests.get(data)
-                    response.raise_for_status()
+            add_run = await cognee_add(
+                data, datasetName, user=user, dataset_id=datasetId, node_set=node_set
+            )

-                    file_data = await response.content()
-                    # TODO: Update add call with dataset info
-                    return await cognee_add(file_data)
-            else:
-                add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
-
-                if isinstance(add_run, PipelineRunErrored):
-                    return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
-                return add_run.model_dump()
+            if isinstance(add_run, PipelineRunErrored):
+                return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
+            return add_run.model_dump()
         except Exception as error:
             return JSONResponse(status_code=409, content={"error": str(error)})
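For reference, a minimal client-side sketch of how the new `node_set` form field on the add endpoint might be exercised. The host, route prefix, auth token, and file name are assumptions, not part of this PR; only the `datasetName` and `node_set` form fields mirror the diff above.

```python
# Hypothetical client call; host, route prefix, token, and file name are placeholders.
import requests

with open("notes.txt", "rb") as file_handle:
    response = requests.post(
        "http://localhost:8000/v1/add",  # assumed deployment URL and prefix
        headers={"Authorization": "Bearer <token>"},  # placeholder credentials
        files=[("data", file_handle)],
        data=[
            ("datasetName", "my_dataset"),
            # Repeated form fields are collected into the List[str] node_set parameter.
            ("node_set", "project_alpha"),
            ("node_set", "confidential"),
        ],
    )

response.raise_for_status()
print(response.json())
```

FastAPI gathers the repeated `node_set` fields into the `Optional[List[str]]` parameter added above, and the router forwards the value to `cognee_add(..., node_set=node_set)`.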
@@ -20,6 +20,7 @@ class SearchPayloadDTO(InDTO):
     datasets: Optional[list[str]] = Field(default=None)
     dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
     query: str = Field(default="What is in the document?")
+    node_name: Optional[list[str]] = Field(default=None, example=[])
     top_k: Optional[int] = Field(default=10)

@@ -79,6 +80,7 @@ def get_search_router() -> APIRouter:
         - **datasets** (Optional[List[str]]): List of dataset names to search within
         - **dataset_ids** (Optional[List[UUID]]): List of dataset UUIDs to search within
         - **query** (str): The search query string
+        - **node_name** Optional[list[str]]: Filter results to specific node_sets defined in the add pipeline (for targeted search).
         - **top_k** (Optional[int]): Maximum number of results to return (default: 10)

         ## Response
@@ -102,6 +104,7 @@ def get_search_router() -> APIRouter:
                 "datasets": payload.datasets,
                 "dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
                 "query": payload.query,
+                "node_name": payload.node_name,
                 "top_k": payload.top_k,
             },
         )
@@ -115,6 +118,7 @@ def get_search_router() -> APIRouter:
                 user=user,
                 datasets=payload.datasets,
                 dataset_ids=payload.dataset_ids,
+                node_name=payload.node_name,
                 top_k=payload.top_k,
             )
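Correspondingly, a search request can narrow retrieval to data tagged with particular node sets via the new `node_name` field. A hedged sketch of the request body follows, limited to the fields visible in the `SearchPayloadDTO` hunk (the payload may also require a search-type field that is not shown here); the URL and token are placeholders.

```python
# Hypothetical search call; only fields visible in the SearchPayloadDTO hunk are shown.
import requests

payload = {
    "query": "What is in the document?",
    "datasets": ["my_dataset"],
    # Restrict results to data points tagged with these node sets during add.
    "node_name": ["project_alpha"],
    "top_k": 10,
}

response = requests.post(
    "http://localhost:8000/v1/search",  # assumed deployment URL and prefix
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # placeholder credentials
)
response.raise_for_status()
print(response.json())
```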
@@ -1,6 +1,7 @@
 from uuid import UUID
 from typing import Union, Optional, List, Type

+from cognee.modules.engine.models.node_set import NodeSet
 from cognee.modules.users.models import User
 from cognee.modules.search.types import SearchType
 from cognee.modules.users.methods import get_default_user
@@ -17,7 +18,7 @@ async def search(
     dataset_ids: Optional[Union[list[UUID], UUID]] = None,
     system_prompt_path: str = "answer_simple_question.txt",
     top_k: int = 10,
-    node_type: Optional[Type] = None,
+    node_type: Optional[Type] = NodeSet,
     node_name: Optional[List[str]] = None,
     save_interaction: bool = False,
     last_k: Optional[int] = None,
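Because `node_type` now defaults to `NodeSet` in the public `search` API, library callers only need to supply `node_name` to scope a search to particular node sets. A rough sketch under that assumption; the `query_type` and `query_text` parameter names come from cognee's public API rather than this diff, and the labels are illustrative.

```python
import asyncio

import cognee
from cognee.modules.search.types import SearchType


async def main():
    # node_type defaults to NodeSet after this change, so node_name alone
    # scopes the search to data tagged with these node sets at add time.
    results = await cognee.search(
        query_type=SearchType.GRAPH_COMPLETION,
        query_text="Summarize the documents tagged project_alpha",
        node_name=["project_alpha"],
        top_k=10,
    )
    print(results)


asyncio.run(main())
```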
@@ -76,7 +76,7 @@ class CogneeGraph(CogneeAbstractGraph):
         start_time = time.time()

         # Determine projection strategy
-        if node_type is not None and node_name is not None:
+        if node_type is not None and node_name not in [None, []]:
             nodes_data, edges_data = await adapter.get_nodeset_subgraph(
                 node_type=node_type, node_name=node_name
             )
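The guard change above means an empty `node_name` list no longer triggers the node-set subgraph path. Below is a self-contained sketch of just that selection logic; the adapter and its methods are illustrative stand-ins, and only the `node_name not in [None, []]` check mirrors the diff.

```python
import asyncio
from typing import List, Optional, Type


class NodeSet:
    """Stand-in for cognee.modules.engine.models.node_set.NodeSet."""


class FakeAdapter:
    async def get_nodeset_subgraph(self, node_type, node_name):
        return [f"subgraph for {node_name}"], []

    async def full_graph(self):
        # Placeholder for whatever the full-graph projection path calls.
        return ["all nodes"], []


async def project(adapter, node_type: Optional[Type], node_name: Optional[List[str]]):
    # Take the node-set subgraph only for a non-empty node_name list;
    # both None and [] now fall back to projecting the whole graph.
    if node_type is not None and node_name not in [None, []]:
        return await adapter.get_nodeset_subgraph(node_type=node_type, node_name=node_name)
    return await adapter.full_graph()


print(asyncio.run(project(FakeAdapter(), NodeSet, ["project_alpha"])))
print(asyncio.run(project(FakeAdapter(), NodeSet, [])))  # empty list -> full graph now
```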
@@ -4,6 +4,7 @@ import asyncio
 from uuid import UUID
 from typing import Callable, List, Optional, Type, Union

+from cognee.modules.engine.models.node_set import NodeSet
 from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
 from cognee.modules.search.exceptions import UnsupportedSearchTypeError
 from cognee.context_global_variables import set_database_global_context_variables
@@ -38,7 +39,7 @@ async def search(
     user: User,
     system_prompt_path="answer_simple_question.txt",
     top_k: int = 10,
-    node_type: Optional[Type] = None,
+    node_type: Optional[Type] = NodeSet,
     node_name: Optional[List[str]] = None,
     save_interaction: Optional[bool] = False,
     last_k: Optional[int] = None,
@@ -67,6 +68,8 @@ async def search(
         dataset_ids=dataset_ids,
         system_prompt_path=system_prompt_path,
         top_k=top_k,
+        node_type=node_type,
+        node_name=node_name,
         save_interaction=save_interaction,
         last_k=last_k,
     )
@@ -102,7 +105,7 @@ async def specific_search(
     user: User,
     system_prompt_path="answer_simple_question.txt",
     top_k: int = 10,
-    node_type: Optional[Type] = None,
+    node_type: Optional[Type] = NodeSet,
     node_name: Optional[List[str]] = None,
     save_interaction: Optional[bool] = False,
     last_k: Optional[int] = None,
@@ -173,6 +176,8 @@ async def authorized_search(
     dataset_ids: Optional[list[UUID]] = None,
     system_prompt_path: str = "answer_simple_question.txt",
     top_k: int = 10,
+    node_type: Optional[Type] = NodeSet,
+    node_name: Optional[List[str]] = None,
     save_interaction: bool = False,
     last_k: Optional[int] = None,
 ) -> list:
@@ -194,7 +199,9 @@ async def authorized_search(
         user,
         system_prompt_path,
         top_k,
-        save_interaction,
+        node_type=node_type,
+        node_name=node_name,
+        save_interaction=save_interaction,
         last_k=last_k,
     )
@@ -210,6 +217,8 @@ async def specific_search_by_context(
     user: User,
     system_prompt_path: str,
     top_k: int,
+    node_type: Optional[Type] = NodeSet,
+    node_name: Optional[List[str]] = None,
     save_interaction: bool = False,
     last_k: Optional[int] = None,
 ):
@@ -219,7 +228,15 @@ async def specific_search_by_context(
     """

     async def _search_by_context(
-        dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
+        dataset,
+        user,
+        query_type,
+        query_text,
+        system_prompt_path,
+        top_k,
+        node_type: Optional[Type] = NodeSet,
+        node_name: Optional[List[str]] = None,
+        last_k: Optional[int] = None,
     ):
         # Set database configuration in async context for each dataset user has access for
         await set_database_global_context_variables(dataset.id, dataset.owner_id)
@@ -229,6 +246,8 @@ async def specific_search_by_context(
             user,
             system_prompt_path=system_prompt_path,
             top_k=top_k,
+            node_type=node_type,
+            node_name=node_name,
             save_interaction=save_interaction,
             last_k=last_k,
         )
@@ -243,7 +262,15 @@ async def specific_search_by_context(
     for dataset in search_datasets:
         tasks.append(
             _search_by_context(
-                dataset, user, query_type, query_text, system_prompt_path, top_k, last_k
+                dataset,
+                user,
+                query_type,
+                query_text,
+                system_prompt_path,
+                top_k,
+                node_type=node_type,
+                node_name=node_name,
+                last_k=last_k,
             )
        )
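The per-dataset fan-out in `specific_search_by_context` now forwards `node_type` and `node_name` to each `_search_by_context` task. A simplified, self-contained sketch of that pattern; the `Dataset` class and `run_search` coroutine are stand-ins, and the final gathering of tasks is assumed since it is not shown in the hunks above.

```python
import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Dataset:
    id: str
    owner_id: str


async def run_search(dataset: Dataset, query_text: str, node_name: Optional[List[str]] = None):
    # Stand-in for _search_by_context: in cognee it first pins the database
    # context to (dataset.id, dataset.owner_id), then runs the search.
    return f"{dataset.id}: results for {query_text!r} filtered by node sets {node_name}"


async def search_all(datasets: List[Dataset], query_text: str, node_name: Optional[List[str]]):
    # One task per accessible dataset, with the node-set filter forwarded to each.
    tasks = [run_search(dataset, query_text, node_name=node_name) for dataset in datasets]
    return await asyncio.gather(*tasks)


print(asyncio.run(search_all([Dataset("docs", "u1"), Dataset("code", "u2")], "query", ["project_alpha"])))
```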
@@ -3,8 +3,8 @@ import uuid
 from unittest.mock import AsyncMock, MagicMock, patch

 import pytest
-from pylint.checkers.utils import node_type

+from cognee.modules.engine.models.node_set import NodeSet
 from cognee.modules.search.exceptions import UnsupportedSearchTypeError
 from cognee.modules.search.methods.search import search, specific_search
 from cognee.modules.search.types import SearchType
@@ -63,7 +63,7 @@ async def test_search(
     mock_user,
     system_prompt_path="answer_simple_question.txt",
     top_k=10,
-    node_type=None,
+    node_type=NodeSet,
     node_name=None,
     save_interaction=False,
     last_k=None,
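In the spirit of the updated test above, a hedged sketch of how a unit test might pin down the new default. Everything here is illustrative rather than copied from cognee's test suite: the `search` stand-in, the backend mock, and the call shape are assumptions, and the async test relies on pytest-asyncio.

```python
# Illustrative test sketch; not the actual cognee test module.
from typing import List, Optional, Type
from unittest.mock import AsyncMock

import pytest


class NodeSet:
    """Stand-in for cognee.modules.engine.models.node_set.NodeSet."""


async def search(query, *, backend, node_type: Optional[Type] = NodeSet, node_name: Optional[List[str]] = None):
    # Minimal stand-in mirroring the new default: node_type is NodeSet unless overridden.
    return await backend(query, node_type=node_type, node_name=node_name)


@pytest.mark.asyncio
async def test_search_defaults_to_nodeset():
    backend = AsyncMock(return_value=[])
    await search("what is in the document?", backend=backend)
    # The backend should receive NodeSet as the default node_type.
    backend.assert_awaited_once_with("what is in the document?", node_type=NodeSet, node_name=None)
```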