feat: Add search by dataset for cognee

Added the ability for cognee users to search by dataset

Feature COG-912
This commit is contained in:
Igor Ilic 2024-12-17 11:20:22 +01:00
parent bfa0f06fb4
commit 630ab556db
3 changed files with 39 additions and 4 deletions

View file

@ -1,7 +1,7 @@
import json
from uuid import UUID
from enum import Enum
from typing import Callable, Dict
from typing import Callable, Dict, Union
from cognee.exceptions import InvalidValueError
from cognee.modules.search.operations import log_query, log_result
@ -22,7 +22,12 @@ class SearchType(Enum):
CHUNKS = "CHUNKS"
COMPLETION = "COMPLETION"
async def search(query_type: SearchType, query_text: str, user: User = None) -> list:
async def search(query_type: SearchType, query_text: str, user: User = None,
datasets: Union[list[str], str, None] = None) -> list:
# Normalize a single dataset name into a one-element list so that only lists are handled from here on
if isinstance(datasets, str):
datasets = [datasets]
if user is None:
user = await get_default_user()
@ -31,7 +36,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) ->
query = await log_query(query_text, str(query_type), user.id)
own_document_ids = await get_document_ids_for_user(user.id)
own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(query_type, query_text, user)
filtered_search_results = []

View file

@ -1,2 +1,3 @@
from .Data import Data
from .Dataset import Dataset
from .DatasetData import DatasetData

View file

@ -1,9 +1,11 @@
from uuid import UUID
from sqlalchemy import select
from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Dataset, DatasetData
from ...models import ACL, Resource, Permission
async def get_document_ids_for_user(user_id: UUID) -> list[str]:
async def get_document_ids_for_user(user_id: UUID, datasets: list[str] = None) -> list[str]:
db_engine = get_relational_engine()
async with db_engine.get_async_session() as session:
@ -18,4 +20,31 @@ async def get_document_ids_for_user(user_id: UUID) -> list[str]:
)
)).all()
# Optional dataset filter: when `datasets` is non-empty, narrow `document_ids`
# down to the documents linked to at least one of the named datasets.
if datasets:
# A set is used so a document found in several requested datasets is kept once.
documnets_ids_in_dataset = set()
# If datasets are specified, filter out documents that aren't part of the specified datasets
for dataset in datasets:
# Find the id of the dataset with this name that is owned by the user.
# NOTE(review): SQLAlchemy's `.one()` raises unless exactly one row matches, so an
# unknown dataset name (or a duplicated name for the same owner) would raise here
# rather than being skipped -- confirm this is the intended behavior.
dataset_id = (await session.scalars(
select(Dataset.id)
.where(
Dataset.name == dataset,
Dataset.owner_id == user_id,
)
)).one()
# Check which documents are connected to this dataset.
# NOTE(review): this issues one query per (dataset, document) pair -- a single
# query over DatasetData for the dataset would likely be cheaper; verify.
for document_id in document_ids:
# NOTE(review): `.one()` presumably raises when the document is NOT linked to the
# dataset, which would make the `if data_id:` check below unreachable for the
# "not linked" case -- `.one_or_none()` looks like what was meant; confirm.
data_id = (await session.scalars(
select(DatasetData.data_id)
.where(
DatasetData.dataset_id == dataset_id,
DatasetData.data_id == document_id,
)
)).one()
# If the document is related to the dataset, add it to the return value
if data_id:
documnets_ids_in_dataset.add(document_id)
# The ids pass through a set, so the order of the returned list is not guaranteed.
return list(documnets_ids_in_dataset)
# No dataset filter was requested: return `document_ids` unfiltered.
return document_ids