fix: Resolve issue with searching datasets by UUID
This commit is contained in:
parent
8e85cd2ae3
commit
7f3aebd06d
3 changed files with 12 additions and 4 deletions
|
|
@ -11,7 +11,7 @@ from cognee.shared.data_models import KnowledgeGraph
|
|||
|
||||
class CognifyPayloadDTO(BaseModel):
|
||||
datasets: List[str]
|
||||
dataset_ids: Optional[List[UUID]]
|
||||
dataset_ids: Optional[List[UUID]] = None
|
||||
graph_model: Optional[BaseModel] = KnowledgeGraph
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,9 +10,12 @@ from cognee.modules.search.operations import get_history
|
|||
from cognee.modules.users.methods import get_authenticated_user
|
||||
|
||||
|
||||
# Note: Datasets sent by name will only map to datasets owned by the request sender
|
||||
# To search for datasets not owned by the request sender dataset UUID is needed
|
||||
class SearchPayloadDTO(InDTO):
|
||||
search_type: SearchType
|
||||
datasets: Optional[Union[list[UUID], list[str]]] = None
|
||||
datasets: Optional[list[str]] = None
|
||||
dataset_ids: Optional[list[UUID]] = None
|
||||
query: str
|
||||
|
||||
|
||||
|
|
@ -45,6 +48,7 @@ def get_search_router() -> APIRouter:
|
|||
query_type=payload.search_type,
|
||||
user=user,
|
||||
datasets=payload.datasets,
|
||||
dataset_ids=payload.dataset_ids,
|
||||
)
|
||||
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -6,13 +6,15 @@ from cognee.modules.search.types import SearchType
|
|||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.search.methods import search as search_function
|
||||
from cognee.modules.data.methods import get_authorized_existing_datasets
|
||||
from cognee.modules.data.exceptions import DatasetNotFoundError
|
||||
|
||||
|
||||
async def search(
|
||||
query_text: str,
|
||||
query_type: SearchType = SearchType.GRAPH_COMPLETION,
|
||||
user: User = None,
|
||||
datasets: Union[list[UUID], list[str], str, UUID, None] = None,
|
||||
datasets: Optional[Union[list[str], str]] = None,
|
||||
dataset_ids: Optional[Union[list[UUID], UUID]] = None,
|
||||
system_prompt_path: str = "answer_simple_question.txt",
|
||||
top_k: int = 10,
|
||||
node_type: Optional[Type] = None,
|
||||
|
|
@ -29,11 +31,13 @@ async def search(
|
|||
if datasets is not None and [all(isinstance(dataset, str) for dataset in datasets)]:
|
||||
datasets = await get_authorized_existing_datasets(datasets, "read", user)
|
||||
datasets = [dataset.id for dataset in datasets]
|
||||
if not datasets:
|
||||
raise DatasetNotFoundError(message="No datasets found.")
|
||||
|
||||
filtered_search_results = await search_function(
|
||||
query_text=query_text,
|
||||
query_type=query_type,
|
||||
dataset_ids=datasets,
|
||||
dataset_ids=dataset_ids if dataset_ids else datasets,
|
||||
user=user,
|
||||
system_prompt_path=system_prompt_path,
|
||||
top_k=top_k,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue