refactor: Rename access control functions
This commit is contained in:
parent
53521c2068
commit
46c509778f
4 changed files with 14 additions and 17 deletions
|
|
@ -24,7 +24,7 @@ async def set_session_user_context_variable(user):
|
|||
session_user.set(user)
|
||||
|
||||
|
||||
def check_multi_user_support():
|
||||
def multi_user_support_possible():
|
||||
graph_db_config = get_graph_context_config()
|
||||
vector_db_config = get_vectordb_context_config()
|
||||
return (
|
||||
|
|
@ -33,24 +33,21 @@ def check_multi_user_support():
|
|||
)
|
||||
|
||||
|
||||
def check_backend_access_control_mode():
|
||||
def backend_access_control_enabled():
|
||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||
if backend_access_control is None:
|
||||
# If backend access control is not defined in environment variables,
|
||||
# enable it by default if graph and vector DBs can support it, otherwise disable it
|
||||
return check_multi_user_support()
|
||||
return multi_user_support_possible()
|
||||
elif backend_access_control.lower() == "true":
|
||||
# If enabled, ensure that the current graph and vector DBs can support it
|
||||
multi_user_support = check_multi_user_support()
|
||||
multi_user_support = multi_user_support_possible()
|
||||
if not multi_user_support:
|
||||
raise EnvironmentError(
|
||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
||||
)
|
||||
else:
|
||||
return True
|
||||
else:
|
||||
# If explicitly disabled, return false
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||
|
|
@ -74,7 +71,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
base_config = get_base_config()
|
||||
|
||||
if not check_backend_access_control_mode():
|
||||
if not backend_access_control_enabled():
|
||||
return
|
||||
|
||||
user = await get_user(user_id)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.context_global_variables import check_backend_access_control_mode
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
|
@ -74,7 +74,7 @@ async def search(
|
|||
)
|
||||
|
||||
# Use search function filtered by permissions if access control is enabled
|
||||
if check_backend_access_control_mode():
|
||||
if backend_access_control_enabled():
|
||||
search_results = await authorized_search(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
|
|
@ -156,7 +156,7 @@ async def search(
|
|||
)
|
||||
else:
|
||||
# This is for maintaining backwards compatibility
|
||||
if check_backend_access_control_mode():
|
||||
if backend_access_control_enabled():
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from ..models import User
|
|||
from ..get_fastapi_users import get_fastapi_users
|
||||
from .get_default_user import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.context_global_variables import check_backend_access_control_mode
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
|
||||
logger = get_logger("get_authenticated_user")
|
||||
|
|
@ -13,7 +13,7 @@ logger = get_logger("get_authenticated_user")
|
|||
# Check environment variable to determine authentication requirement
|
||||
REQUIRE_AUTHENTICATION = (
|
||||
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
|
||||
or check_backend_access_control_mode()
|
||||
or backend_access_control_enabled()
|
||||
)
|
||||
|
||||
fastapi_users = get_fastapi_users()
|
||||
|
|
|
|||
|
|
@ -147,9 +147,9 @@ async def main():
|
|||
f"{name}: expected single-element list, got {len(search_results)}"
|
||||
)
|
||||
|
||||
from cognee.context_global_variables import check_backend_access_control_mode
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
if check_backend_access_control_mode():
|
||||
if backend_access_control_enabled():
|
||||
text = search_results[0]["search_result"][0]
|
||||
else:
|
||||
text = search_results[0]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue