From 9e201035493e6a38d614db9cbbd87b7d69a926d6 Mon Sep 17 00:00:00 2001
From: Igor Ilic
Date: Thu, 4 Sep 2025 20:59:00 +0200
Subject: [PATCH] feat: Enable multi-user mode to work with memify

---
 .../v1/memify/routers/get_memify_router.py    | 12 +++----
 .../modules/graph/cognee_graph/CogneeGraph.py |  2 +-
 cognee/modules/memify/memify.py               | 32 ++++++++++---------
 3 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/cognee/api/v1/memify/routers/get_memify_router.py b/cognee/api/v1/memify/routers/get_memify_router.py
index 817eef9bd..cf1df8f71 100644
--- a/cognee/api/v1/memify/routers/get_memify_router.py
+++ b/cognee/api/v1/memify/routers/get_memify_router.py
@@ -23,8 +23,8 @@ class MemifyPayloadDTO(InDTO):
     )
     enrichment_tasks: Optional[List[str]] = Field(default=None, examples=[[]])
     data: Optional[str] = Field(default="")
-    dataset_names: Optional[List[str]] = Field(default=None, examples=[[]])
-    dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
+    dataset_name: Optional[str] = Field(default=None)
+    dataset_id: Optional[UUID] = Field(default=None)
     node_name: Optional[List[str]] = Field(default=None, examples=[[]])
     run_in_background: Optional[bool] = Field(default=False)
@@ -46,8 +46,8 @@ def get_memify_router() -> APIRouter:
        - **data** Optional[List[str]]: The data to ingest. Can be any text data when custom extraction and enrichment tasks are used. Data provided here will be forwarded to the first extraction task in the pipeline as input. If no data is provided the whole graph (or subgraph if node_name/node_type is specified) will be forwarded
-       - **dataset_names** (Optional[List[str]]): Name of the datasets to memify
-       - **dataset_ids** (Optional[List[UUID]]): List of UUIDs of an already existing dataset
+       - **dataset_name** (Optional[str]): Name of the dataset to memify
+       - **dataset_id** (Optional[UUID]): UUID of an already existing dataset
        - **node_name** (Optional[List[str]]): Filter graph to specific named entities (for targeted search). Used when no data is provided.
        - **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
@@ -75,7 +75,7 @@ def get_memify_router() -> APIRouter:
                 additional_properties={"endpoint": "POST /v1/memify"},
             )
 
-        if not payload.dataset_ids and not payload.dataset_names:
+        if not payload.dataset_id and not payload.dataset_name:
             raise ValueError("Either datasetId or datasetName must be provided.")
 
         try:
@@ -85,7 +85,7 @@ def get_memify_router() -> APIRouter:
                 extraction_tasks=payload.extraction_tasks,
                 enrichment_tasks=payload.enrichment_tasks,
                 data=payload.data,
-                datasets=payload.dataset_ids if payload.dataset_ids else payload.dataset_names,
+                dataset=payload.dataset_id if payload.dataset_id else payload.dataset_name,
                 node_name=payload.node_name,
                 user=user,
             )
diff --git a/cognee/modules/graph/cognee_graph/CogneeGraph.py b/cognee/modules/graph/cognee_graph/CogneeGraph.py
index 924532ce0..acfe04de7 100644
--- a/cognee/modules/graph/cognee_graph/CogneeGraph.py
+++ b/cognee/modules/graph/cognee_graph/CogneeGraph.py
@@ -76,7 +76,7 @@ class CogneeGraph(CogneeAbstractGraph):
         start_time = time.time()
 
         # Determine projection strategy
-        if node_type is not None and node_name not in [None, []]:
+        if node_type is not None and node_name not in [None, [], ""]:
             nodes_data, edges_data = await adapter.get_nodeset_subgraph(
                 node_type=node_type, node_name=node_name
             )
diff --git a/cognee/modules/memify/memify.py b/cognee/modules/memify/memify.py
index d8e1087f2..2d9b32a1b 100644
--- a/cognee/modules/memify/memify.py
+++ b/cognee/modules/memify/memify.py
@@ -4,7 +4,7 @@
 from uuid import UUID
 from cognee.shared.logging_utils import get_logger
 from cognee.modules.retrieval.utils.brute_force_triplet_search import get_memory_fragment
-
+from cognee.context_global_variables import set_database_global_context_variables
 from cognee.modules.engine.models.node_set import NodeSet
 from cognee.modules.pipelines import run_pipeline
 from cognee.modules.pipelines.tasks.task import Task
@@ -29,7 +29,7 @@ async def memify(
     extraction_tasks: Union[List[Task], List[str]] = None,
     enrichment_tasks: Union[List[Task], List[str]] = None,
     data: Optional[Any] = None,
-    datasets: Union[str, list[str], list[UUID]] = None,
+    dataset: Union[str, UUID] = "main_dataset",
     user: User = None,
     node_type: Optional[Type] = NodeSet,
     node_name: Optional[List[str]] = None,
@@ -53,10 +53,7 @@ async def memify(
         data: The data to ingest. Can be anything when custom extraction and enrichment tasks are used.
             Data provided here will be forwarded to the first extraction task in the pipeline as input.
             If no data is provided the whole graph (or subgraph if node_name/node_type is specified) will be forwarded
-        datasets: Dataset name(s) or dataset uuid to process. Processes all available datasets if None.
-            - Single dataset: "my_dataset"
-            - Multiple datasets: ["docs", "research", "reports"]
-            - None: Process all datasets for the user
+        dataset: Dataset name or dataset UUID to process. Defaults to "main_dataset".
         user: User context for authentication and data access. Uses default if None.
         node_type: Filter graph to specific entity types (for advanced filtering). Used when no data is provided.
         node_name: Filter graph to specific named entities (for targeted search). Used when no data is provided.
@@ -80,7 +77,17 @@ async def memify(
         )
     ]
 
+    await setup()
+
+    user, authorized_dataset_list = await resolve_authorized_user_datasets(dataset, user)
+    authorized_dataset = authorized_dataset_list[0]
+
     if not data:
+        # The per-dataset database context only takes effect when ENABLE_BACKEND_ACCESS_CONTROL is set to True
+        await set_database_global_context_variables(
+            authorized_dataset.id, authorized_dataset.owner_id
+        )
+
         memory_fragment = await get_memory_fragment(node_type=node_type, node_name=node_name)
         # Subgraphs should be a single element in the list to represent one data item
         data = [memory_fragment]
@@ -90,14 +97,9 @@ async def memify(
         *enrichment_tasks,
     ]
 
-    await setup()
-
-    user, authorized_datasets = await resolve_authorized_user_datasets(datasets, user)
-
-    for dataset in authorized_datasets:
-        await reset_dataset_pipeline_run_status(
-            dataset.id, user, pipeline_names=["memify_pipeline"]
-        )
+    await reset_dataset_pipeline_run_status(
+        authorized_dataset.id, user, pipeline_names=["memify_pipeline"]
+    )
 
     # By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
     pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
@@ -108,7 +110,7 @@ async def memify(
         tasks=memify_tasks,
         user=user,
         data=data,
-        datasets=datasets,
+        datasets=authorized_dataset.id,
         vector_db_config=vector_db_config,
         graph_db_config=graph_db_config,
         incremental_loading=False,
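
Reviewer note (illustrative, not part of the patch): after this change, memify
accepts a single dataset per call. A minimal sketch of the updated library
call, assuming a configured cognee environment and the default
extraction/enrichment tasks; the dataset name and node set filter are
placeholders, and run_in_background mirrors the flag used in the function body:

    import asyncio

    from cognee.modules.memify.memify import memify

    async def main():
        # `dataset` replaces the old `datasets` parameter and accepts one
        # dataset name or UUID; it defaults to "main_dataset".
        await memify(
            dataset="support_tickets",  # placeholder dataset name
            node_name=["tickets"],      # placeholder node set filter
            run_in_background=False,    # block until the memify pipeline finishes
        )

    asyncio.run(main())

Internally the call now resolves the single authorized dataset up front and,
when ENABLE_BACKEND_ACCESS_CONTROL is enabled, pins the database context to
that dataset's id and owner_id before the memory fragment is projected.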
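
At the HTTP level the POST /v1/memify body likewise carries a single dataset.
A request sketch using only the Python standard library; the camelCase keys
are an assumption inferred from the "datasetId or datasetName" validation
message, and the host, port, and dataset name are placeholders (authentication
headers omitted):

    import json
    import urllib.request

    # Placeholder URL for a locally running cognee API server.
    request = urllib.request.Request(
        "http://localhost:8000/v1/memify",
        data=json.dumps(
            {
                "datasetName": "support_tickets",  # provide exactly one of datasetName / datasetId
                "runInBackground": False,
            }
        ).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

    with urllib.request.urlopen(request) as response:
        print(response.read().decode("utf-8"))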