fix: Resolve searching of dataset when you have permission but are not the owner
This commit is contained in:
parent
a1bf8416bd
commit
450320ba2c
3 changed files with 7 additions and 5 deletions
|
|
@ -4,7 +4,7 @@ from typing import Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||||
from cognee.modules.users.models import User
|
from cognee.modules.users.methods import get_user
|
||||||
|
|
||||||
# Note: ContextVar allows us to use different graph db configurations in Cognee
|
# Note: ContextVar allows us to use different graph db configurations in Cognee
|
||||||
# for different async tasks, threads and processes
|
# for different async tasks, threads and processes
|
||||||
|
|
@ -12,7 +12,7 @@ vector_db_config = ContextVar("vector_db_config", default=None)
|
||||||
graph_db_config = ContextVar("graph_db_config", default=None)
|
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||||
|
|
||||||
|
|
||||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user: User):
|
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||||
"""
|
"""
|
||||||
If backend access control is enabled this function will ensure all datasets have their own databases,
|
If backend access control is enabled this function will ensure all datasets have their own databases,
|
||||||
access to which will be enforced by given permissions.
|
access to which will be enforced by given permissions.
|
||||||
|
|
@ -25,7 +25,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: Cognee dataset name or id
|
dataset: Cognee dataset name or id
|
||||||
user: User object
|
user_id: UUID of user
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
|
|
@ -34,6 +34,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user:
|
||||||
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||||
return
|
return
|
||||||
|
|
||||||
|
user = await get_user(user_id)
|
||||||
|
|
||||||
# To ensure permissions are enforced properly all datasets will have their own databases
|
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -133,7 +133,7 @@ async def run_pipeline(
|
||||||
check_dataset_name(dataset.name)
|
check_dataset_name(dataset.name)
|
||||||
|
|
||||||
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
# Will only be used if ENABLE_BACKEND_ACCESS_CONTROL is set to True
|
||||||
await set_database_global_context_variables(dataset.name, user)
|
await set_database_global_context_variables(dataset.name, user.id)
|
||||||
|
|
||||||
# Ugly hack, but no easier way to do this.
|
# Ugly hack, but no easier way to do this.
|
||||||
if pipeline_name == "add_pipeline":
|
if pipeline_name == "add_pipeline":
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,7 @@ async def specific_search_by_context(
|
||||||
|
|
||||||
async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k):
|
async def _search_by_context(dataset, user, query_type, query_text, system_prompt_path, top_k):
|
||||||
# Set database configuration in async context for each dataset user has access for
|
# Set database configuration in async context for each dataset user has access for
|
||||||
await set_database_global_context_variables(dataset.id, user)
|
await set_database_global_context_variables(dataset.id, dataset.owner_id)
|
||||||
search_results = await specific_search(
|
search_results = await specific_search(
|
||||||
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
|
query_type, query_text, user, system_prompt_path=system_prompt_path, top_k=top_k
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue