cognee/cognee/modules/search/methods/search.py

212 lines
7.6 KiB
Python

import os
import json
import asyncio
from uuid import UUID
from typing import Callable, List, Optional, Type, Union
from cognee.context_global_variables import set_database_global_context_variables
from cognee.exceptions import InvalidValueError
from cognee.modules.retrieval.chunks_retriever import ChunksRetriever
from cognee.modules.retrieval.insights_retriever import InsightsRetriever
from cognee.modules.retrieval.summaries_retriever import SummariesRetriever
from cognee.modules.retrieval.completion_retriever import CompletionRetriever
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
from cognee.modules.retrieval.graph_summary_completion_retriever import (
GraphSummaryCompletionRetriever,
)
from cognee.modules.retrieval.graph_completion_cot_retriever import GraphCompletionCotRetriever
from cognee.modules.retrieval.graph_completion_context_extension_retriever import (
GraphCompletionContextExtensionRetriever,
)
from cognee.modules.retrieval.code_retriever import CodeRetriever
from cognee.modules.retrieval.cypher_search_retriever import CypherSearchRetriever
from cognee.modules.retrieval.natural_language_retriever import NaturalLanguageRetriever
from cognee.modules.retrieval.user_qa_feedback import UserQAFeedback
from cognee.modules.search.types import SearchType
from cognee.modules.storage.utils import JSONEncoder
from cognee.modules.users.models import User
from cognee.modules.data.models import Dataset
from cognee.shared.utils import send_telemetry
from cognee.modules.users.permissions.methods import get_specific_user_permission_datasets
from cognee.modules.search.operations import log_query, log_result
async def search(
query_text: str,
query_type: SearchType,
dataset_ids: Union[list[UUID], None],
user: User,
system_prompt_path="answer_simple_question.txt",
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
last_k: Optional[int] = None,
):
"""
Args:
query_text:
query_type:
datasets:
user:
system_prompt_path:
top_k:
last_k:
Returns:
Notes:
Searching by dataset is only available in ENABLE_BACKEND_ACCESS_CONTROL mode
"""
# Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
return await authorized_search(
query_text, query_type, user, dataset_ids, system_prompt_path, top_k
)
query = await log_query(query_text, query_type.value, user.id)
search_results = await specific_search(
query_type,
query_text,
user,
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
last_k=last_k,
)
await log_result(
query.id,
json.dumps(
search_results if len(search_results) > 1 else search_results[0], cls=JSONEncoder
),
user.id,
)
return search_results
async def specific_search(
query_type: SearchType,
query: str,
user: User,
system_prompt_path="answer_simple_question.txt",
top_k: int = 10,
node_type: Optional[Type] = None,
node_name: Optional[List[str]] = None,
last_k: Optional[int] = None,
) -> list:
search_tasks: dict[SearchType, Callable] = {
SearchType.SUMMARIES: SummariesRetriever(top_k=top_k).get_completion,
SearchType.INSIGHTS: InsightsRetriever(top_k=top_k).get_completion,
SearchType.CHUNKS: ChunksRetriever(top_k=top_k).get_completion,
SearchType.RAG_COMPLETION: CompletionRetriever(
system_prompt_path=system_prompt_path, top_k=top_k
).get_completion,
SearchType.GRAPH_COMPLETION: GraphCompletionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion,
SearchType.GRAPH_COMPLETION_COT: GraphCompletionCotRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion,
SearchType.GRAPH_COMPLETION_CONTEXT_EXTENSION: GraphCompletionContextExtensionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion,
SearchType.GRAPH_SUMMARY_COMPLETION: GraphSummaryCompletionRetriever(
system_prompt_path=system_prompt_path,
top_k=top_k,
node_type=node_type,
node_name=node_name,
).get_completion,
SearchType.CODE: CodeRetriever(top_k=top_k).get_completion,
SearchType.CYPHER: CypherSearchRetriever().get_completion,
SearchType.NATURAL_LANGUAGE: NaturalLanguageRetriever().get_completion,
SearchType.FEEDBACK: UserQAFeedback(last_k=last_k).add_feedback,
}
search_task = search_tasks.get(query_type)
if search_task is None:
raise InvalidValueError(message=f"Unsupported search type: {query_type}")
send_telemetry("cognee.search EXECUTION STARTED", user.id)
results = await search_task(query)
send_telemetry("cognee.search EXECUTION COMPLETED", user.id)
return results
async def authorized_search(
query_text: str,
query_type: SearchType,
user: User = None,
dataset_ids: Optional[list[UUID]] = None,
system_prompt_path: str = "answer_simple_question.txt",
top_k: int = 10,
) -> list:
"""
Verifies access for provided datasets or uses all datasets user has read access for and performs search per dataset.
Not to be used outside of active access control mode.
"""
query = await log_query(query_text, query_type.value, user.id)
# Find datasets user has read access for (if datasets are provided only return them. Provided user has read access)
search_datasets = await get_specific_user_permission_datasets(user.id, "read", dataset_ids)
# Searches all provided datasets and handles setting up of appropriate database context based on permissions
search_results = await specific_search_by_context(
search_datasets, query_text, query_type, user, system_prompt_path, top_k
)
await log_result(query.id, json.dumps(search_results, cls=JSONEncoder), user.id)
return search_results
async def specific_search_by_context(
search_datasets: list[Dataset],
query_text: str,
query_type: SearchType,
user: User,
system_prompt_path: str,
top_k: int,
):
"""
Searches all provided datasets and handles setting up of appropriate database context based on permissions.
Not to be used outside of active access control mode.
"""
async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k):
# Set database configuration in async context for each dataset user has access for
await set_database_global_context_variables(dataset.id, dataset.owner_id)
search_results = await specific_search(
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
)
return {
"search_result": search_results,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
}
# Search every dataset async based on query and appropriate database configuration
tasks = []
for dataset in search_datasets:
tasks.append(
_search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k)
)
return await asyncio.gather(*tasks)