refactor: Rename access control functions

This commit is contained in:
Igor Ilic 2025-11-04 12:06:16 +01:00
parent 53521c2068
commit 46c509778f
4 changed files with 14 additions and 17 deletions

View file

@ -24,7 +24,7 @@ async def set_session_user_context_variable(user):
session_user.set(user)
def check_multi_user_support():
def multi_user_support_possible():
graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config()
return (
@ -33,24 +33,21 @@ def check_multi_user_support():
)
def check_backend_access_control_mode():
def backend_access_control_enabled():
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
if backend_access_control is None:
# If backend access control is not defined in environment variables,
# enable it by default if graph and vector DBs can support it, otherwise disable it
return check_multi_user_support()
return multi_user_support_possible()
elif backend_access_control.lower() == "true":
# If enabled, ensure that the current graph and vector DBs can support it
multi_user_support = check_multi_user_support()
multi_user_support = multi_user_support_possible()
if not multi_user_support:
raise EnvironmentError(
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
)
else:
return True
else:
# If explicitly disabled, return false
return False
return True
return False
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
@ -74,7 +71,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
base_config = get_base_config()
if not check_backend_access_control_mode():
if not backend_access_control_enabled():
return
user = await get_user(user_id)

View file

@ -8,7 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables
from cognee.context_global_variables import check_backend_access_control_mode
from cognee.context_global_variables import backend_access_control_enabled
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -74,7 +74,7 @@ async def search(
)
# Use search function filtered by permissions if access control is enabled
if check_backend_access_control_mode():
if backend_access_control_enabled():
search_results = await authorized_search(
query_type=query_type,
query_text=query_text,
@ -156,7 +156,7 @@ async def search(
)
else:
# This is for maintaining backwards compatibility
if check_backend_access_control_mode():
if backend_access_control_enabled():
return_value = []
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)

View file

@ -5,7 +5,7 @@ from ..models import User
from ..get_fastapi_users import get_fastapi_users
from .get_default_user import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.context_global_variables import check_backend_access_control_mode
from cognee.context_global_variables import backend_access_control_enabled
logger = get_logger("get_authenticated_user")
@ -13,7 +13,7 @@ logger = get_logger("get_authenticated_user")
# Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = (
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
or check_backend_access_control_mode()
or backend_access_control_enabled()
)
fastapi_users = get_fastapi_users()

View file

@ -147,9 +147,9 @@ async def main():
f"{name}: expected single-element list, got {len(search_results)}"
)
from cognee.context_global_variables import check_backend_access_control_mode
from cognee.context_global_variables import backend_access_control_enabled
if check_backend_access_control_mode():
if backend_access_control_enabled():
text = search_results[0]["search_result"][0]
else:
text = search_results[0]