diff --git a/cognee/api/v1/add/routers/get_add_router.py b/cognee/api/v1/add/routers/get_add_router.py index 9de818b7d..f27d559e1 100644 --- a/cognee/api/v1/add/routers/get_add_router.py +++ b/cognee/api/v1/add/routers/get_add_router.py @@ -21,7 +21,7 @@ def get_add_router() -> APIRouter: async def add( data: List[UploadFile] = File(default=None), datasetName: Optional[str] = Form(default=None), - datasetId: Union[UUID, Literal[""], None] = Form(default=None, examples=[""]), + datasetId: Union[UUID, None] = Form(default=None, examples=[""]), node_set: Optional[List[str]] = Form(default=[""], example=[""]), user: User = Depends(get_authenticated_user), ): diff --git a/cognee/api/v1/memify/routers/get_memify_router.py b/cognee/api/v1/memify/routers/get_memify_router.py index edac2775a..817eef9bd 100644 --- a/cognee/api/v1/memify/routers/get_memify_router.py +++ b/cognee/api/v1/memify/routers/get_memify_router.py @@ -17,15 +17,15 @@ logger = get_logger() class MemifyPayloadDTO(InDTO): - extraction_tasks: List[str] = Field( + extraction_tasks: Optional[List[str]] = Field( default=None, examples=[[]], ) - enrichment_tasks: List[str] = (Field(default=None, examples=[[]]),) - data: Optional[str] = (Field(default=None),) - dataset_names: Optional[List[str]] = Field(default=None) + enrichment_tasks: Optional[List[str]] = Field(default=None, examples=[[]]) + data: Optional[str] = Field(default="") + dataset_names: Optional[List[str]] = Field(default=None, examples=[[]]) dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]]) - node_name: Optional[List[str]] = Field(default=None) + node_name: Optional[List[str]] = Field(default=None, examples=[[]]) run_in_background: Optional[bool] = Field(default=False) @@ -78,10 +78,10 @@ def get_memify_router() -> APIRouter: if not payload.dataset_ids and not payload.dataset_names: raise ValueError("Either datasetId or datasetName must be provided.") - from cognee import memify - try: - memify_run = await memify( + from cognee.modules.memify import memify as cognee_memify + + memify_run = await cognee_memify( extraction_tasks=payload.extraction_tasks, enrichment_tasks=payload.enrichment_tasks, data=payload.data, diff --git a/cognee/modules/data/methods/load_or_create_datasets.py b/cognee/modules/data/methods/load_or_create_datasets.py index 1d6ef3efb..2c9a6497c 100644 --- a/cognee/modules/data/methods/load_or_create_datasets.py +++ b/cognee/modules/data/methods/load_or_create_datasets.py @@ -2,7 +2,7 @@ from typing import List, Union from uuid import UUID from cognee.modules.data.models import Dataset -from cognee.modules.data.methods import create_authorized_dataset +from cognee.modules.data.methods.create_authorized_dataset import create_authorized_dataset from cognee.modules.data.exceptions import DatasetNotFoundError diff --git a/cognee/modules/memify/memify.py b/cognee/modules/memify/memify.py index 80afd7325..d8e1087f2 100644 --- a/cognee/modules/memify/memify.py +++ b/cognee/modules/memify/memify.py @@ -26,8 +26,8 @@ logger = get_logger("memify") async def memify( - extraction_tasks: Union[List[Task], List[str]] = [Task(extract_subgraph_chunks)], - enrichment_tasks: Union[List[Task], List[str]] = [Task(add_rule_associations)], + extraction_tasks: Union[List[Task], List[str]] = None, + enrichment_tasks: Union[List[Task], List[str]] = None, data: Optional[Any] = None, datasets: Union[str, list[str], list[UUID]] = None, user: User = None, @@ -68,6 +68,18 @@ async def memify( Use pipeline_run_id from return value to monitor progress. """ + # Use default coding rules tasks if no tasks were provided + if not extraction_tasks: + extraction_tasks = [Task(extract_subgraph_chunks)] + if not enrichment_tasks: + enrichment_tasks = [ + Task( + add_rule_associations, + rules_nodeset_name="coding_agent_rules", + task_config={"batch_size": 1}, + ) + ] + if not data: memory_fragment = await get_memory_fragment(node_type=node_type, node_name=node_name) # Subgraphs should be a single element in the list to represent one data item diff --git a/cognee/modules/retrieval/coding_rules_retriever.py b/cognee/modules/retrieval/coding_rules_retriever.py index 2578d1ee1..364ff3236 100644 --- a/cognee/modules/retrieval/coding_rules_retriever.py +++ b/cognee/modules/retrieval/coding_rules_retriever.py @@ -7,8 +7,11 @@ logger = get_logger("CodingRulesRetriever") class CodingRulesRetriever: """Retriever for handling codeing rule based searches.""" - def __init__(self, rules_nodeset_name): + def __init__(self, rules_nodeset_name="coding_agent_rules"): if isinstance(rules_nodeset_name, list): + if not rules_nodeset_name: + # If there is no provided nodeset set to coding_agent_rules + rules_nodeset_name = ["coding_agent_rules"] rules_nodeset_name = rules_nodeset_name[0] self.rules_nodeset_name = rules_nodeset_name """Initialize retriever with search parameters."""