feat: Enable multi-user mode to work with memify
This commit is contained in:
parent
e06cf11f49
commit
9e20103549
3 changed files with 24 additions and 22 deletions
|
|
@ -23,8 +23,8 @@ class MemifyPayloadDTO(InDTO):
|
|||
)
|
||||
enrichment_tasks: Optional[List[str]] = Field(default=None, examples=[[]])
|
||||
data: Optional[str] = Field(default="")
|
||||
dataset_names: Optional[List[str]] = Field(default=None, examples=[[]])
|
||||
dataset_ids: Optional[List[UUID]] = Field(default=None, examples=[[]])
|
||||
dataset_name: Optional[str] = Field(default=None)
|
||||
dataset_id: Optional[UUID] = Field(default=None, examples=[[""]])
|
||||
node_name: Optional[List[str]] = Field(default=None, examples=[[]])
|
||||
run_in_background: Optional[bool] = Field(default=False)
|
||||
|
||||
|
|
@ -46,8 +46,8 @@ def get_memify_router() -> APIRouter:
|
|||
- **data** (Optional[str]): The data to ingest. Can be any text data when custom extraction and enrichment tasks are used.
|
||||
Data provided here will be forwarded to the first extraction task in the pipeline as input.
|
||||
If no data is provided, the whole graph (or subgraph if node_name/node_type is specified) will be forwarded.
|
||||
- **dataset_names** (Optional[List[str]]): Name of the datasets to memify
|
||||
- **dataset_ids** (Optional[List[UUID]]): List of UUIDs of already existing datasets
|
||||
- **dataset_name** (Optional[str]): Name of the dataset to memify
|
||||
- **dataset_id** (Optional[UUID]): UUID of an already existing dataset
|
||||
- **node_name** (Optional[List[str]]): Filter graph to specific named entities (for targeted search). Used when no data is provided.
|
||||
- **run_in_background** (Optional[bool]): Whether to execute processing asynchronously. Defaults to False (blocking).
|
||||
|
||||
|
|
@ -75,7 +75,7 @@ def get_memify_router() -> APIRouter:
|
|||
additional_properties={"endpoint": "POST /v1/memify"},
|
||||
)
|
||||
|
||||
if not payload.dataset_ids and not payload.dataset_names:
|
||||
if not payload.dataset_id and not payload.dataset_name:
|
||||
raise ValueError("Either datasetId or datasetName must be provided.")
|
||||
|
||||
try:
|
||||
|
|
@ -85,7 +85,7 @@ def get_memify_router() -> APIRouter:
|
|||
extraction_tasks=payload.extraction_tasks,
|
||||
enrichment_tasks=payload.enrichment_tasks,
|
||||
data=payload.data,
|
||||
datasets=payload.dataset_ids if payload.dataset_ids else payload.dataset_names,
|
||||
dataset=payload.dataset_id if payload.dataset_id else payload.dataset_name,
|
||||
node_name=payload.node_name,
|
||||
user=user,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
start_time = time.time()
|
||||
|
||||
# Determine projection strategy
|
||||
if node_type is not None and node_name not in [None, []]:
|
||||
if node_type is not None and node_name not in [None, [], ""]:
|
||||
nodes_data, edges_data = await adapter.get_nodeset_subgraph(
|
||||
node_type=node_type, node_name=node_name
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from uuid import UUID
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
|
||||
from cognee.modules.retrieval.utils.brute_force_triplet_search import get_memory_fragment
|
||||
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.pipelines import run_pipeline
|
||||
from cognee.modules.pipelines.tasks.task import Task
|
||||
|
|
@ -29,7 +29,7 @@ async def memify(
|
|||
extraction_tasks: Union[List[Task], List[str]] = None,
|
||||
enrichment_tasks: Union[List[Task], List[str]] = None,
|
||||
data: Optional[Any] = None,
|
||||
datasets: Union[str, list[str], list[UUID]] = None,
|
||||
dataset: Union[str, UUID] = "main_dataset",
|
||||
user: User = None,
|
||||
node_type: Optional[Type] = NodeSet,
|
||||
node_name: Optional[List[str]] = None,
|
||||
|
|
@ -53,10 +53,7 @@ async def memify(
|
|||
data: The data to ingest. Can be anything when custom extraction and enrichment tasks are used.
|
||||
Data provided here will be forwarded to the first extraction task in the pipeline as input.
|
||||
If no data is provided, the whole graph (or subgraph if node_name/node_type is specified) will be forwarded.
|
||||
datasets: Dataset name(s) or dataset uuid to process. Processes all available datasets if None.
|
||||
- Single dataset: "my_dataset"
|
||||
- Multiple datasets: ["docs", "research", "reports"]
|
||||
- None: Process all datasets for the user
|
||||
dataset: Dataset name or dataset uuid to process.
|
||||
user: User context for authentication and data access. Uses default if None.
|
||||
node_type: Filter graph to specific entity types (for advanced filtering). Used when no data is provided.
|
||||
node_name: Filter graph to specific named entities (for targeted search). Used when no data is provided.
|
||||
|
|
@ -80,7 +77,17 @@ async def memify(
|
|||
)
|
||||
]
|
||||
|
||||
await setup()
|
||||
|
||||
user, authorized_dataset_list = await resolve_authorized_user_datasets(dataset, user)
|
||||
authorized_dataset = authorized_dataset_list[0]
|
||||
|
||||
if not data:
|
||||
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
||||
await set_database_global_context_variables(
|
||||
authorized_dataset.id, authorized_dataset.owner_id
|
||||
)
|
||||
|
||||
memory_fragment = await get_memory_fragment(node_type=node_type, node_name=node_name)
|
||||
# Subgraphs should be a single element in the list to represent one data item
|
||||
data = [memory_fragment]
|
||||
|
|
@ -90,14 +97,9 @@ async def memify(
|
|||
*enrichment_tasks,
|
||||
]
|
||||
|
||||
await setup()
|
||||
|
||||
user, authorized_datasets = await resolve_authorized_user_datasets(datasets, user)
|
||||
|
||||
for dataset in authorized_datasets:
|
||||
await reset_dataset_pipeline_run_status(
|
||||
dataset.id, user, pipeline_names=["memify_pipeline"]
|
||||
)
|
||||
await reset_dataset_pipeline_run_status(
|
||||
authorized_dataset.id, user, pipeline_names=["memify_pipeline"]
|
||||
)
|
||||
|
||||
# By calling get pipeline executor we get a function that will have the run_pipeline run in the background or a function that we will need to wait for
|
||||
pipeline_executor_func = get_pipeline_executor(run_in_background=run_in_background)
|
||||
|
|
@ -108,7 +110,7 @@ async def memify(
|
|||
tasks=memify_tasks,
|
||||
user=user,
|
||||
data=data,
|
||||
datasets=datasets,
|
||||
datasets=authorized_dataset.id,
|
||||
vector_db_config=vector_db_config,
|
||||
graph_db_config=graph_db_config,
|
||||
incremental_loading=False,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue