From 630ab556dbc24d8002745bc4ee082121032b1eb4 Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Tue, 17 Dec 2024 11:20:22 +0100 Subject: [PATCH] feat: Add search by dataset for cognee Added ability to search by datasets for cognee users Feature COG-912 --- cognee/api/v1/search/search_v2.py | 11 +++++-- cognee/modules/data/models/__init__.py | 1 + .../methods/get_document_ids_for_user.py | 31 ++++++++++++++++++- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/cognee/api/v1/search/search_v2.py b/cognee/api/v1/search/search_v2.py index 6a5da4648..222ec6791 100644 --- a/cognee/api/v1/search/search_v2.py +++ b/cognee/api/v1/search/search_v2.py @@ -1,7 +1,7 @@ import json from uuid import UUID from enum import Enum -from typing import Callable, Dict +from typing import Callable, Dict, Union from cognee.exceptions import InvalidValueError from cognee.modules.search.operations import log_query, log_result @@ -22,7 +22,12 @@ class SearchType(Enum): CHUNKS = "CHUNKS" COMPLETION = "COMPLETION" -async def search(query_type: SearchType, query_text: str, user: User = None) -> list: +async def search(query_type: SearchType, query_text: str, user: User = None, + datasets: Union[list[str], str, None] = None) -> list: + # We use lists from now on for datasets + if isinstance(datasets, str): + datasets = [datasets] + if user is None: user = await get_default_user() @@ -31,7 +36,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) -> query = await log_query(query_text, str(query_type), user.id) - own_document_ids = await get_document_ids_for_user(user.id) + own_document_ids = await get_document_ids_for_user(user.id, datasets) search_results = await specific_search(query_type, query_text, user) filtered_search_results = [] diff --git a/cognee/modules/data/models/__init__.py b/cognee/modules/data/models/__init__.py index 5d79dbd40..bd5774f88 100644 --- a/cognee/modules/data/models/__init__.py +++ b/cognee/modules/data/models/__init__.py @@ -1,2 +1,3 @@ from .Data import Data from .Dataset import Dataset +from .DatasetData import DatasetData diff --git a/cognee/modules/users/permissions/methods/get_document_ids_for_user.py b/cognee/modules/users/permissions/methods/get_document_ids_for_user.py index 79736db0f..7e052ebc9 100644 --- a/cognee/modules/users/permissions/methods/get_document_ids_for_user.py +++ b/cognee/modules/users/permissions/methods/get_document_ids_for_user.py @@ -1,9 +1,11 @@ from uuid import UUID from sqlalchemy import select from cognee.infrastructure.databases.relational import get_relational_engine +from cognee.modules.data.models import Dataset, DatasetData from ...models import ACL, Resource, Permission -async def get_document_ids_for_user(user_id: UUID) -> list[str]: + +async def get_document_ids_for_user(user_id: UUID, datasets: list[str] = None) -> list[str]: db_engine = get_relational_engine() async with db_engine.get_async_session() as session: @@ -18,4 +20,31 @@ async def get_document_ids_for_user(user_id: UUID) -> list[str]: ) )).all() + if datasets: + documnets_ids_in_dataset = set() + # If datasets are specified filter out documents that aren't part of the specified datasets + for dataset in datasets: + # Find dataset id for dataset element + dataset_id = (await session.scalars( + select(Dataset.id) + .where( + Dataset.name == dataset, + Dataset.owner_id == user_id, + ) + )).one() + + # Check which documents are connected to this dataset + for document_id in document_ids: + data_id = (await session.scalars( + select(DatasetData.data_id) + .where( + DatasetData.dataset_id == dataset_id, + DatasetData.data_id == document_id, + ) + )).one() + + # If document is related to dataset added it to return value + if data_id: + documnets_ids_in_dataset.add(document_id) + return list(documnets_ids_in_dataset) return document_ids