feat: Add search by dataset for cognee
Added ability to search by datasets for cognee users Feature COG-912
This commit is contained in:
parent
bfa0f06fb4
commit
630ab556db
3 changed files with 39 additions and 4 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, Dict
|
from typing import Callable, Dict, Union
|
||||||
|
|
||||||
from cognee.exceptions import InvalidValueError
|
from cognee.exceptions import InvalidValueError
|
||||||
from cognee.modules.search.operations import log_query, log_result
|
from cognee.modules.search.operations import log_query, log_result
|
||||||
|
|
@ -22,7 +22,12 @@ class SearchType(Enum):
|
||||||
CHUNKS = "CHUNKS"
|
CHUNKS = "CHUNKS"
|
||||||
COMPLETION = "COMPLETION"
|
COMPLETION = "COMPLETION"
|
||||||
|
|
||||||
async def search(query_type: SearchType, query_text: str, user: User = None) -> list:
|
async def search(query_type: SearchType, query_text: str, user: User = None,
|
||||||
|
datasets: Union[list[str], str, None] = None) -> list:
|
||||||
|
# We use lists from now on for datasets
|
||||||
|
if isinstance(datasets, str):
|
||||||
|
datasets = [datasets]
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await get_default_user()
|
user = await get_default_user()
|
||||||
|
|
||||||
|
|
@ -31,7 +36,7 @@ async def search(query_type: SearchType, query_text: str, user: User = None) ->
|
||||||
|
|
||||||
query = await log_query(query_text, str(query_type), user.id)
|
query = await log_query(query_text, str(query_type), user.id)
|
||||||
|
|
||||||
own_document_ids = await get_document_ids_for_user(user.id)
|
own_document_ids = await get_document_ids_for_user(user.id, datasets)
|
||||||
search_results = await specific_search(query_type, query_text, user)
|
search_results = await specific_search(query_type, query_text, user)
|
||||||
|
|
||||||
filtered_search_results = []
|
filtered_search_results = []
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
from .Data import Data
|
from .Data import Data
|
||||||
from .Dataset import Dataset
|
from .Dataset import Dataset
|
||||||
|
from .DatasetData import DatasetData
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from cognee.infrastructure.databases.relational import get_relational_engine
|
from cognee.infrastructure.databases.relational import get_relational_engine
|
||||||
|
from cognee.modules.data.models import Dataset, DatasetData
|
||||||
from ...models import ACL, Resource, Permission
|
from ...models import ACL, Resource, Permission
|
||||||
|
|
||||||
async def get_document_ids_for_user(user_id: UUID) -> list[str]:
|
|
||||||
|
async def get_document_ids_for_user(user_id: UUID, datasets: list[str] = None) -> list[str]:
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
|
|
@ -18,4 +20,31 @@ async def get_document_ids_for_user(user_id: UUID) -> list[str]:
|
||||||
)
|
)
|
||||||
)).all()
|
)).all()
|
||||||
|
|
||||||
|
if datasets:
|
||||||
|
documnets_ids_in_dataset = set()
|
||||||
|
# If datasets are specified filter out documents that aren't part of the specified datasets
|
||||||
|
for dataset in datasets:
|
||||||
|
# Find dataset id for dataset element
|
||||||
|
dataset_id = (await session.scalars(
|
||||||
|
select(Dataset.id)
|
||||||
|
.where(
|
||||||
|
Dataset.name == dataset,
|
||||||
|
Dataset.owner_id == user_id,
|
||||||
|
)
|
||||||
|
)).one()
|
||||||
|
|
||||||
|
# Check which documents are connected to this dataset
|
||||||
|
for document_id in document_ids:
|
||||||
|
data_id = (await session.scalars(
|
||||||
|
select(DatasetData.data_id)
|
||||||
|
.where(
|
||||||
|
DatasetData.dataset_id == dataset_id,
|
||||||
|
DatasetData.data_id == document_id,
|
||||||
|
)
|
||||||
|
)).one()
|
||||||
|
|
||||||
|
# If document is related to dataset added it to return value
|
||||||
|
if data_id:
|
||||||
|
documnets_ids_in_dataset.add(document_id)
|
||||||
|
return list(documnets_ids_in_dataset)
|
||||||
return document_ids
|
return document_ids
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue