feat: Add search by dataset for cognee
Added the ability for cognee users to search by dataset. Feature: COG-912
This commit is contained in:
parent
bfa0f06fb4
commit
630ab556db
3 changed files with 39 additions and 4 deletions
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from uuid import UUID
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable, Dict, Union
|
||||
|
||||
from cognee.exceptions import InvalidValueError
|
||||
from cognee.modules.search.operations import log_query, log_result
|
||||
|
|
@ -22,7 +22,12 @@ class SearchType(Enum):
|
|||
CHUNKS = "CHUNKS"
|
||||
COMPLETION = "COMPLETION"
|
||||
|
||||
async def search(query_type: SearchType, query_text: str, user: User = None) -> list:
|
||||
async def search(query_type: SearchType, query_text: str, user: User = None,
|
||||
datasets: Union[list[str], str, None] = None) -> list:
|
||||
# We use lists from now on for datasets
|
||||
if isinstance(datasets, str):
|
||||
datasets = [datasets]
|
||||
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
|
|
@ -31,7 +36,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) ->
|
|||
|
||||
query = await log_query(query_text, str(query_type), user.id)
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id)
|
||||
own_document_ids = await get_document_ids_for_user(user.id, datasets)
|
||||
search_results = await specific_search(query_type, query_text, user)
|
||||
|
||||
filtered_search_results = []
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
from .Data import Data
|
||||
from .Dataset import Dataset
|
||||
from .DatasetData import DatasetData
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||
from cognee.modules.data.models import Dataset, DatasetData
|
||||
from ...models import ACL, Resource, Permission
|
||||
|
||||
async def get_document_ids_for_user(user_id: UUID) -> list[str]:
|
||||
|
||||
async def get_document_ids_for_user(user_id: UUID, datasets: list[str] = None) -> list[str]:
|
||||
db_engine = get_relational_engine()
|
||||
|
||||
async with db_engine.get_async_session() as session:
|
||||
|
|
@ -18,4 +20,31 @@ async def get_document_ids_for_user(user_id: UUID) -> list[str]:
|
|||
)
|
||||
)).all()
|
||||
|
||||
if datasets:
|
||||
documnets_ids_in_dataset = set()
|
||||
# If datasets are specified, filter out documents that aren't part of the specified datasets
|
||||
for dataset in datasets:
|
||||
# Find dataset id for dataset element
|
||||
dataset_id = (await session.scalars(
|
||||
select(Dataset.id)
|
||||
.where(
|
||||
Dataset.name == dataset,
|
||||
Dataset.owner_id == user_id,
|
||||
)
|
||||
)).one()
|
||||
|
||||
# Check which documents are connected to this dataset
|
||||
for document_id in document_ids:
|
||||
data_id = (await session.scalars(
|
||||
select(DatasetData.data_id)
|
||||
.where(
|
||||
DatasetData.dataset_id == dataset_id,
|
||||
DatasetData.data_id == document_id,
|
||||
)
|
||||
)).one()
|
||||
|
||||
# If the document is related to the dataset, add it to the return value
|
||||
if data_id:
|
||||
documnets_ids_in_dataset.add(document_id)
|
||||
return list(documnets_ids_in_dataset)
|
||||
return document_ids
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue