refactor: Rename access control functions

This commit is contained in:
Igor Ilic 2025-11-04 12:06:16 +01:00
parent 53521c2068
commit 46c509778f
4 changed files with 14 additions and 17 deletions

View file

@ -24,7 +24,7 @@ async def set_session_user_context_variable(user):
session_user.set(user) session_user.set(user)
def check_multi_user_support(): def multi_user_support_possible():
graph_db_config = get_graph_context_config() graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config() vector_db_config = get_vectordb_context_config()
return ( return (
@ -33,24 +33,21 @@ def check_multi_user_support():
) )
def check_backend_access_control_mode(): def backend_access_control_enabled():
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None) backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
if backend_access_control is None: if backend_access_control is None:
# If backend access control is not defined in environment variables, # If backend access control is not defined in environment variables,
# enable it by default if graph and vector DBs can support it, otherwise disable it # enable it by default if graph and vector DBs can support it, otherwise disable it
return check_multi_user_support() return multi_user_support_possible()
elif backend_access_control.lower() == "true": elif backend_access_control.lower() == "true":
# If enabled, ensure that the current graph and vector DBs can support it # If enabled, ensure that the current graph and vector DBs can support it
multi_user_support = check_multi_user_support() multi_user_support = multi_user_support_possible()
if not multi_user_support: if not multi_user_support:
raise EnvironmentError( raise EnvironmentError(
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control." "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
) )
else: return True
return True return False
else:
# If explicitly disabled, return false
return False
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID): async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
@ -74,7 +71,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
base_config = get_base_config() base_config = get_base_config()
if not check_backend_access_control_mode(): if not backend_access_control_enabled():
return return
user = await get_user(user_id) user = await get_user(user_id)

View file

@ -8,7 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables from cognee.context_global_variables import set_database_global_context_variables
from cognee.context_global_variables import check_backend_access_control_mode from cognee.context_global_variables import backend_access_control_enabled
from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -74,7 +74,7 @@ async def search(
) )
# Use search function filtered by permissions if access control is enabled # Use search function filtered by permissions if access control is enabled
if check_backend_access_control_mode(): if backend_access_control_enabled():
search_results = await authorized_search( search_results = await authorized_search(
query_type=query_type, query_type=query_type,
query_text=query_text, query_text=query_text,
@ -156,7 +156,7 @@ async def search(
) )
else: else:
# This is for maintaining backwards compatibility # This is for maintaining backwards compatibility
if check_backend_access_control_mode(): if backend_access_control_enabled():
return_value = [] return_value = []
for search_result in search_results: for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result) prepared_search_results = await prepare_search_result(search_result)

View file

@ -5,7 +5,7 @@ from ..models import User
from ..get_fastapi_users import get_fastapi_users from ..get_fastapi_users import get_fastapi_users
from .get_default_user import get_default_user from .get_default_user import get_default_user
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.context_global_variables import check_backend_access_control_mode from cognee.context_global_variables import backend_access_control_enabled
logger = get_logger("get_authenticated_user") logger = get_logger("get_authenticated_user")
@ -13,7 +13,7 @@ logger = get_logger("get_authenticated_user")
# Check environment variable to determine authentication requirement # Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = ( REQUIRE_AUTHENTICATION = (
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true" os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
or check_backend_access_control_mode() or backend_access_control_enabled()
) )
fastapi_users = get_fastapi_users() fastapi_users = get_fastapi_users()

View file

@ -147,9 +147,9 @@ async def main():
f"{name}: expected single-element list, got {len(search_results)}" f"{name}: expected single-element list, got {len(search_results)}"
) )
from cognee.context_global_variables import check_backend_access_control_mode from cognee.context_global_variables import backend_access_control_enabled
if check_backend_access_control_mode(): if backend_access_control_enabled():
text = search_results[0]["search_result"][0] text = search_results[0]["search_result"][0]
else: else:
text = search_results[0] text = search_results[0]