feat: Add search by dataset for cognee

Added ability to search by datasets for cognee users

Feature COG-912
This commit is contained in:
Igor Ilic 2024-12-17 11:20:22 +01:00
parent bfa0f06fb4
commit 630ab556db
3 changed files with 39 additions and 4 deletions

View file

@ -1,7 +1,7 @@
import json import json
from uuid import UUID from uuid import UUID
from enum import Enum from enum import Enum
from typing import Callable, Dict from typing import Callable, Dict, Union
from cognee.exceptions import InvalidValueError from cognee.exceptions import InvalidValueError
from cognee.modules.search.operations import log_query, log_result from cognee.modules.search.operations import log_query, log_result
@ -22,7 +22,12 @@ class SearchType(Enum):
CHUNKS = "CHUNKS" CHUNKS = "CHUNKS"
COMPLETION = "COMPLETION" COMPLETION = "COMPLETION"
async def search(query_type: SearchType, query_text: str, user: User = None) -> list: async def search(query_type: SearchType, query_text: str, user: User = None,
datasets: Union[list[str], str, None] = None) -> list:
# We use lists from now on for datasets
if isinstance(datasets, str):
datasets = [datasets]
if user is None: if user is None:
user = await get_default_user() user = await get_default_user()
@ -31,7 +36,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) ->
query = await log_query(query_text, str(query_type), user.id) query = await log_query(query_text, str(query_type), user.id)
own_document_ids = await get_document_ids_for_user(user.id) own_document_ids = await get_document_ids_for_user(user.id, datasets)
search_results = await specific_search(query_type, query_text, user) search_results = await specific_search(query_type, query_text, user)
filtered_search_results = [] filtered_search_results = []

View file

@ -1,2 +1,3 @@
from .Data import Data from .Data import Data
from .Dataset import Dataset from .Dataset import Dataset
from .DatasetData import DatasetData

View file

@ -1,9 +1,11 @@
from uuid import UUID from uuid import UUID
from sqlalchemy import select from sqlalchemy import select
from cognee.infrastructure.databases.relational import get_relational_engine from cognee.infrastructure.databases.relational import get_relational_engine
from cognee.modules.data.models import Dataset, DatasetData
from ...models import ACL, Resource, Permission from ...models import ACL, Resource, Permission
async def get_document_ids_for_user(user_id: UUID) -> list[str]:
async def get_document_ids_for_user(user_id: UUID, datasets: list[str] = None) -> list[str]:
db_engine = get_relational_engine() db_engine = get_relational_engine()
async with db_engine.get_async_session() as session: async with db_engine.get_async_session() as session:
@ -18,4 +20,31 @@ async def get_document_ids_for_user(user_id: UUID) -> list[str]:
) )
)).all() )).all()
if datasets:
documnets_ids_in_dataset = set()
# If datasets are specified filter out documents that aren't part of the specified datasets
for dataset in datasets:
# Find dataset id for dataset element
dataset_id = (await session.scalars(
select(Dataset.id)
.where(
Dataset.name == dataset,
Dataset.owner_id == user_id,
)
)).one()
# Check which documents are connected to this dataset
for document_id in document_ids:
data_id = (await session.scalars(
select(DatasetData.data_id)
.where(
DatasetData.dataset_id == dataset_id,
DatasetData.data_id == document_id,
)
)).one()
# If document is related to dataset added it to return value
if data_id:
documnets_ids_in_dataset.add(document_id)
return list(documnets_ids_in_dataset)
return document_ids return document_ids