fix: Resolve import issue with creating auth dataset
This commit is contained in:
parent
805f443cd6
commit
e06cf11f49
5 changed files with 28 additions and 13 deletions
|
|
@ -21,7 +21,7 @@ def get_add_router() -> APIRouter:
|
||||||
async def add(
|
async def add(
|
||||||
data: List[UploadFile] = File(default=None),
|
data: List[UploadFile] = File(default=None),
|
||||||
datasetName: Optional[str] = Form(default=None),
|
datasetName: Optional[str] = Form(default=None),
|
||||||
datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]),
|
datasetId: Union[UUID, None] = Form(default=None, examples=[""]),
|
||||||
node_set: Optional[List[str]] = Form(default=[""], example=[""]),
|
node_set: Optional[List[str]] = Form(default=[""], example=[""]),
|
||||||
user: User = Depends(get_authenticated_user),
|
user: User = Depends(get_authenticated_user),
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -17,15 +17,15 @@ logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
class MemifyPayloadDTO(InDTO):
|
class MemifyPayloadDTO(InDTO):
|
||||||
extraction_tasks: List[str] = Field(
|
extraction_tasks: Optional[List[str]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
examples=[[]],
|
examples=[[]],
|
||||||
)
|
)
|
||||||
enrichment_tasks: List[str] = (Field(default=None, examples=[[]]),)
|
enrichment_tasks: Optional[List[str]] = Field(default=None, examples=[[]])
|
||||||
data: Optional[str] = (Field(default=None),)
|
data: Optional[str] = Field(default="")
|
||||||
dataset_names: Optional[List[str]] = Field(default=None)
|
dataset_names: Optional[List[str]] = Field(default=None, examples=[[]])
|
||||||
dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
|
dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
|
||||||
node_name: Optional[List[str]] = Field(default=None)
|
node_name: Optional[List[str]] = Field(default=None, examples=[[]])
|
||||||
run_in_background: Optional[bool] = Field(default=False)
|
run_in_background: Optional[bool] = Field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -78,10 +78,10 @@ def get_memify_router() -> APIRouter:
|
||||||
if not payload.dataset_ids and not payload.dataset_names:
|
if not payload.dataset_ids and not payload.dataset_names:
|
||||||
raise ValueError("Either datasetId or datasetName must be provided.")
|
raise ValueError("Either datasetId or datasetName must be provided.")
|
||||||
|
|
||||||
from cognee import memify
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
memify_run = await memify(
|
from cognee.modules.memify import memify as cognee_memify
|
||||||
|
|
||||||
|
memify_run = await cognee_memify(
|
||||||
extraction_tasks=payload.extraction_tasks,
|
extraction_tasks=payload.extraction_tasks,
|
||||||
enrichment_tasks=payload.enrichment_tasks,
|
enrichment_tasks=payload.enrichment_tasks,
|
||||||
data=payload.data,
|
data=payload.data,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from typing import List, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from cognee.modules.data.models import Dataset
|
from cognee.modules.data.models import Dataset
|
||||||
from cognee.modules.data.methods import create_authorized_dataset
|
from cognee.modules.data.methods.create_authorized_dataset import create_authorized_dataset
|
||||||
from cognee.modules.data.exceptions import DatasetNotFoundError
|
from cognee.modules.data.exceptions import DatasetNotFoundError
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,8 +26,8 @@ logger = get_logger("memify")
|
||||||
|
|
||||||
|
|
||||||
async def memify(
|
async def memify(
|
||||||
extraction_tasks: Union[List[Task], List[str]] = [Task(extract_subgraph_chunks)],
|
extraction_tasks: Union[List[Task], List[str]] = None,
|
||||||
enrichment_tasks: Union[List[Task], List[str]] = [Task(add_rule_associations)],
|
enrichment_tasks: Union[List[Task], List[str]] = None,
|
||||||
data: Optional[Any] = None,
|
data: Optional[Any] = None,
|
||||||
datasets: Union[str, list[str], list[UUID]] = None,
|
datasets: Union[str, list[str], list[UUID]] = None,
|
||||||
user: User = None,
|
user: User = None,
|
||||||
|
|
@ -68,6 +68,18 @@ async def memify(
|
||||||
Use pipeline_run_id from return value to monitor progress.
|
Use pipeline_run_id from return value to monitor progress.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Use default coding rules tasks if no tasks were provided
|
||||||
|
if not extraction_tasks:
|
||||||
|
extraction_tasks = [Task(extract_subgraph_chunks)]
|
||||||
|
if not enrichment_tasks:
|
||||||
|
enrichment_tasks = [
|
||||||
|
Task(
|
||||||
|
add_rule_associations,
|
||||||
|
rules_nodeset_name="coding_agent_rules",
|
||||||
|
task_config={"batch_size": 1},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
memory_fragment = await get_memory_fragment(node_type=node_type, node_name=node_name)
|
memory_fragment = await get_memory_fragment(node_type=node_type, node_name=node_name)
|
||||||
# Subgraphs should be a single element in the list to represent one data item
|
# Subgraphs should be a single element in the list to represent one data item
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,11 @@ logger = get_logger("CodingRulesRetriever")
|
||||||
class CodingRulesRetriever:
|
class CodingRulesRetriever:
|
||||||
"""Retriever for handling codeing rule based searches."""
|
"""Retriever for handling codeing rule based searches."""
|
||||||
|
|
||||||
def __init__(self, rules_nodeset_name):
|
def __init__(self, rules_nodeset_name="coding_agent_rules"):
|
||||||
if isinstance(rules_nodeset_name, list):
|
if isinstance(rules_nodeset_name, list):
|
||||||
|
if not rules_nodeset_name:
|
||||||
|
# If there is no provided nodeset set to coding_agent_rules
|
||||||
|
rules_nodeset_name = ["coding_agent_rules"]
|
||||||
rules_nodeset_name = rules_nodeset_name[0]
|
rules_nodeset_name = rules_nodeset_name[0]
|
||||||
self.rules_nodeset_name = rules_nodeset_name
|
self.rules_nodeset_name = rules_nodeset_name
|
||||||
"""Initialize retriever with search parameters."""
|
"""Initialize retriever with search parameters."""
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue