feat: Enable nodesets on backend
This commit is contained in:
parent
a2f8b594bd
commit
cf636ba77f
2 changed files with 11 additions and 30 deletions
|
|
@ -25,6 +25,7 @@ def get_add_router() -> APIRouter:
|
||||||
data: List[UploadFile] = File(default=None),
|
data: List[UploadFile] = File(default=None),
|
||||||
datasetName: Optional[str] = Form(default=None),
|
datasetName: Optional[str] = Form(default=None),
|
||||||
datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]),
|
datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]),
|
||||||
|
node_set: Optional[List[str]] = Form(default=[""], example=[""]),
|
||||||
user: User = Depends(get_authenticated_user),
|
user: User = Depends(get_authenticated_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -65,9 +66,7 @@ def get_add_router() -> APIRouter:
|
||||||
send_telemetry(
|
send_telemetry(
|
||||||
"Add API Endpoint Invoked",
|
"Add API Endpoint Invoked",
|
||||||
user.id,
|
user.id,
|
||||||
additional_properties={
|
additional_properties={"endpoint": "POST /v1/add", "node_set": node_set},
|
||||||
"endpoint": "POST /v1/add",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from cognee.api.v1.add import add as cognee_add
|
from cognee.api.v1.add import add as cognee_add
|
||||||
|
|
@ -76,34 +75,13 @@ def get_add_router() -> APIRouter:
|
||||||
raise ValueError("Either datasetId or datasetName must be provided.")
|
raise ValueError("Either datasetId or datasetName must be provided.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if (
|
add_run = await cognee_add(
|
||||||
isinstance(data, str)
|
data, datasetName, user=user, dataset_id=datasetId, node_set=node_set
|
||||||
and data.startswith("http")
|
)
|
||||||
and (os.getenv("ALLOW_HTTP_REQUESTS", "true").lower() == "true")
|
|
||||||
):
|
|
||||||
if "github" in data:
|
|
||||||
# Perform git clone if the URL is from GitHub
|
|
||||||
repo_name = data.split("/")[-1].replace(".git", "")
|
|
||||||
subprocess.run(["git", "clone", data, f".data/{repo_name}"], check=True)
|
|
||||||
# TODO: Update add call with dataset info
|
|
||||||
await cognee_add(
|
|
||||||
"data://.data/",
|
|
||||||
f"{repo_name}",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Fetch and store the data from other types of URL using curl
|
|
||||||
response = requests.get(data)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
file_data = await response.content()
|
if isinstance(add_run, PipelineRunErrored):
|
||||||
# TODO: Update add call with dataset info
|
return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
|
||||||
return await cognee_add(file_data)
|
return add_run.model_dump()
|
||||||
else:
|
|
||||||
add_run = await cognee_add(data, datasetName, user=user, dataset_id=datasetId)
|
|
||||||
|
|
||||||
if isinstance(add_run, PipelineRunErrored):
|
|
||||||
return JSONResponse(status_code=420, content=add_run.model_dump(mode="json"))
|
|
||||||
return add_run.model_dump()
|
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(status_code=409, content={"error": str(error)})
|
return JSONResponse(status_code=409, content={"error": str(error)})
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ class SearchPayloadDTO(InDTO):
|
||||||
datasets: Optional[list[str]] = Field(default=None)
|
datasets: Optional[list[str]] = Field(default=None)
|
||||||
dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
|
dataset_ids: Optional[list[UUID]] = Field(default=None, examples=[[]])
|
||||||
query: str = Field(default="What is in the document?")
|
query: str = Field(default="What is in the document?")
|
||||||
|
node_name: Optional[list[str]] = Field(default=None, example=[])
|
||||||
top_k: Optional[int] = Field(default=10)
|
top_k: Optional[int] = Field(default=10)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -102,6 +103,7 @@ def get_search_router() -> APIRouter:
|
||||||
"datasets": payload.datasets,
|
"datasets": payload.datasets,
|
||||||
"dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
|
"dataset_ids": [str(dataset_id) for dataset_id in payload.dataset_ids or []],
|
||||||
"query": payload.query,
|
"query": payload.query,
|
||||||
|
"node_name": payload.node_name,
|
||||||
"top_k": payload.top_k,
|
"top_k": payload.top_k,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -115,6 +117,7 @@ def get_search_router() -> APIRouter:
|
||||||
user=user,
|
user=user,
|
||||||
datasets=payload.datasets,
|
datasets=payload.datasets,
|
||||||
dataset_ids=payload.dataset_ids,
|
dataset_ids=payload.dataset_ids,
|
||||||
|
node_name=payload.node_name,
|
||||||
top_k=payload.top_k,
|
top_k=payload.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue