cognee/cognee/modules/search/methods/search.py
Igor Ilic 7b5bba2b18
refactor: Unify dataset resolution (#1488)
<!-- .github/pull_request_template.md -->

## Description
Unified dataset resolution mechanisms across cognee

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-10-07 19:14:46 +02:00

371 lines
13 KiB
Python

import os
import json
import asyncio
from uuid import UUID
from fastapi.encoders import jsonable_encoder
from typing import Any, List, Optional, Tuple, Type, Union
from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.search.types import (
SearchResult,
CombinedSearchResult,
SearchResultDataset,
SearchType,
)
from cognee.modules.search.operations import log_query, log_result
from cognee.modules.users.models import User
from cognee.modules.data.models import Dataset
from cognee.modules.data.methods.get_authorized_existing_datasets import (
get_authorized_existing_datasets,
)
from .get_search_type_tools import get_search_type_tools
from .no_access_control_search import no_access_control_search
from ..utils.prepare_search_result import prepare_search_result
async def search(
query_text: str,
query_type: SearchType,
dataset_ids: Union[list[UUID], None],
user: User,
system_prompt_path="answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
use_combined_context: bool = False,
) -> Union[CombinedSearchResult, List[SearchResult]]:
"""
Args:
query_text:
query_type:
datasets:
user:
system_prompt_path:
top_k:
Returns:
Notes:
Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
"""
query = await log_query(query_text, query_type.value, user.id)
send_telemetry("cognee.search EXECUTION STARTED", user.id)
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
search_results = await authorized_search(
query_type=query_type,
query_text=query_text,
user=user,
dataset_ids=dataset_ids,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
use_combined_context=use_combined_context,
)
else:
search_results = [
await no_access_control_search(
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
]
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
await log_result(
query.id,
json.dumps(
jsonable_encoder(
await prepare_search_result(
search_results[0] if isinstance(search_results, list) else search_results
)
if use_combined_context
else [
await prepare_search_result(search_result) for search_result in search_results
]
)
),
user.id,
)
if use_combined_context:
prepared_search_results = await prepare_search_result(
search_results[0] if isinstance(search_results, list) else search_results
)
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
return CombinedSearchResult(
result=result,
graphs=graphs,
context=context,
datasets=[
SearchResultDataset(
id=dataset.id,
name=dataset.name,
)
for dataset in datasets
],
)
else:
# This is for maintaining backwards compatibility
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return_value = []
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)
result = prepared_search_results["result"]
graphs = prepared_search_results["graphs"]
context = prepared_search_results["context"]
datasets = prepared_search_results["datasets"]
if only_context:
return_value.append(
{
"search_result": [context] if context else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"graphs": graphs,
}
)
else:
return_value.append(
{
"search_result": [result] if result else None,
"dataset_id": datasets[0].id,
"dataset_name": datasets[0].name,
"graphs": graphs,
}
)
return return_value
else:
return_value = []
if only_context:
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)
return_value.append(prepared_search_results["context"])
else:
for search_result in search_results:
result, context, datasets = search_result
return_value.append(result)
# For maintaining backwards compatibility
if len(return_value) == 1 and isinstance(return_value[0], list):
return return_value[0]
else:
return return_value
async def authorized_search(
query_type: SearchType,
query_text: str,
user: User,
dataset_ids: Optional[list[UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
use_combined_context: bool = False,
) -> Union[
Tuple[Any, Union[List[Edge], str], List[Dataset]],
List[Tuple[Any, Union[List[Edge], str], List[Dataset]]],
]:
"""
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
Not to be used outside of active access control mode.
"""
# Find datasets user has read access for (if datasets are provided only return them. Provided user has read access)
search_datasets = await get_authorized_existing_datasets(
datasets=dataset_ids, permission_type="read", user=user
)
if use_combined_context:
search_responses = await search_in_datasets_context(
search_datasets=search_datasets,
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=True,
)
context = {}
datasets: List[Dataset] = []
for _, search_context, search_datasets in search_responses:
for dataset in search_datasets:
context[str(dataset.id)] = search_context
datasets.extend(search_datasets)
specific_search_tools = await get_search_type_tools(
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
[get_completion, _] = search_tools
else:
get_completion = search_tools[0]
def prepare_combined_context(
context,
) -> Union[List[Edge], str]:
combined_context = []
for dataset_context in context.values():
combined_context += dataset_context
if combined_context and isinstance(combined_context[0], str):
return "\n".join(combined_context)
return combined_context
combined_context = prepare_combined_context(context)
completion = await get_completion(query_text, combined_context)
return completion, combined_context, datasets
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await search_in_datasets_context(
search_datasets=search_datasets,
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
)
return search_results
async def search_in_datasets_context(
search_datasets: list[Dataset],
query_type: SearchType,
query_text: str,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
context: Optional[Any] = None,
) -> List[Tuple[Any, Union[str, List[Edge]], List[Dataset]]]:
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
Not to be used outside of active access control mode.
"""
async def _search_in_dataset_context(
dataset: Dataset,
query_type: SearchType,
query_text: str,
system_prompt_path: str = "answer_simple_question.txt",
system_prompt: Optional[str] = None,
top_k: int = 10,
node_type: Optional[Type] = NodeSet,
node_name: Optional[List[str]] = None,
save_interaction: bool = False,
last_k: Optional[int] = None,
only_context: bool = False,
context: Optional[Any] = None,
) -> Tuple[Any, Union[str, List[Edge]], List[Dataset]]:
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
specific_search_tools = await get_search_type_tools(
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
)
search_tools = specific_search_tools
if len(search_tools) == 2:
[get_completion, get_context] = search_tools
if only_context:
return None, await get_context(query_text), [dataset]
search_context = context or await get_context(query_text)
search_result = await get_completion(query_text, search_context)
return search_result, search_context, [dataset]
else:
unknown_tool = search_tools[0]
return await unknown_tool(query_text), "", [dataset]
# Search every dataset async based on query and appropriate database configuration
tasks = []
for dataset in search_datasets:
tasks.append(
_search_in_dataset_context(
dataset=dataset,
query_type=query_type,
query_text=query_text,
system_prompt_path=system_prompt_path,
system_prompt=system_prompt,
top_k=top_k,
node_type=node_type,
node_name=node_name,
save_interaction=save_interaction,
last_k=last_k,
only_context=only_context,
context=context,
)
)
return await asyncio.gather(*tasks)