refactor: Rename access control functions
This commit is contained in:
parent
53521c2068
commit
46c509778f
4 changed files with 14 additions and 17 deletions
|
|
@ -24,7 +24,7 @@ async def set_session_user_context_variable(user):
|
||||||
session_user.set(user)
|
session_user.set(user)
|
||||||
|
|
||||||
|
|
||||||
def check_multi_user_support():
|
def multi_user_support_possible():
|
||||||
graph_db_config = get_graph_context_config()
|
graph_db_config = get_graph_context_config()
|
||||||
vector_db_config = get_vectordb_context_config()
|
vector_db_config = get_vectordb_context_config()
|
||||||
return (
|
return (
|
||||||
|
|
@ -33,24 +33,21 @@ def check_multi_user_support():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_backend_access_control_mode():
|
def backend_access_control_enabled():
|
||||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||||
if backend_access_control is None:
|
if backend_access_control is None:
|
||||||
# If backend access control is not defined in environment variables,
|
# If backend access control is not defined in environment variables,
|
||||||
# enable it by default if graph and vector DBs can support it, otherwise disable it
|
# enable it by default if graph and vector DBs can support it, otherwise disable it
|
||||||
return check_multi_user_support()
|
return multi_user_support_possible()
|
||||||
elif backend_access_control.lower() == "true":
|
elif backend_access_control.lower() == "true":
|
||||||
# If enabled, ensure that the current graph and vector DBs can support it
|
# If enabled, ensure that the current graph and vector DBs can support it
|
||||||
multi_user_support = check_multi_user_support()
|
multi_user_support = multi_user_support_possible()
|
||||||
if not multi_user_support:
|
if not multi_user_support:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
||||||
)
|
)
|
||||||
else:
|
return True
|
||||||
return True
|
return False
|
||||||
else:
|
|
||||||
# If explicitly disabled, return false
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||||
|
|
@ -74,7 +71,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
|
|
||||||
base_config = get_base_config()
|
base_config = get_base_config()
|
||||||
|
|
||||||
if not check_backend_access_control_mode():
|
if not backend_access_control_enabled():
|
||||||
return
|
return
|
||||||
|
|
||||||
user = await get_user(user_id)
|
user = await get_user(user_id)
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.shared.utils import send_telemetry
|
from cognee.shared.utils import send_telemetry
|
||||||
from cognee.context_global_variables import set_database_global_context_variables
|
from cognee.context_global_variables import set_database_global_context_variables
|
||||||
from cognee.context_global_variables import check_backend_access_control_mode
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
|
|
||||||
from cognee.modules.engine.models.node_set import NodeSet
|
from cognee.modules.engine.models.node_set import NodeSet
|
||||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||||
|
|
@ -74,7 +74,7 @@ async def search(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use search function filtered by permissions if access control is enabled
|
# Use search function filtered by permissions if access control is enabled
|
||||||
if check_backend_access_control_mode():
|
if backend_access_control_enabled():
|
||||||
search_results = await authorized_search(
|
search_results = await authorized_search(
|
||||||
query_type=query_type,
|
query_type=query_type,
|
||||||
query_text=query_text,
|
query_text=query_text,
|
||||||
|
|
@ -156,7 +156,7 @@ async def search(
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# This is for maintaining backwards compatibility
|
# This is for maintaining backwards compatibility
|
||||||
if check_backend_access_control_mode():
|
if backend_access_control_enabled():
|
||||||
return_value = []
|
return_value = []
|
||||||
for search_result in search_results:
|
for search_result in search_results:
|
||||||
prepared_search_results = await prepare_search_result(search_result)
|
prepared_search_results = await prepare_search_result(search_result)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from ..models import User
|
||||||
from ..get_fastapi_users import get_fastapi_users
|
from ..get_fastapi_users import get_fastapi_users
|
||||||
from .get_default_user import get_default_user
|
from .get_default_user import get_default_user
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
from cognee.context_global_variables import check_backend_access_control_mode
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("get_authenticated_user")
|
logger = get_logger("get_authenticated_user")
|
||||||
|
|
@ -13,7 +13,7 @@ logger = get_logger("get_authenticated_user")
|
||||||
# Check environment variable to determine authentication requirement
|
# Check environment variable to determine authentication requirement
|
||||||
REQUIRE_AUTHENTICATION = (
|
REQUIRE_AUTHENTICATION = (
|
||||||
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
|
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
|
||||||
or check_backend_access_control_mode()
|
or backend_access_control_enabled()
|
||||||
)
|
)
|
||||||
|
|
||||||
fastapi_users = get_fastapi_users()
|
fastapi_users = get_fastapi_users()
|
||||||
|
|
|
||||||
|
|
@ -147,9 +147,9 @@ async def main():
|
||||||
f"{name}: expected single-element list, got {len(search_results)}"
|
f"{name}: expected single-element list, got {len(search_results)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
from cognee.context_global_variables import check_backend_access_control_mode
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
|
|
||||||
if check_backend_access_control_mode():
|
if backend_access_control_enabled():
|
||||||
text = search_results[0]["search_result"][0]
|
text = search_results[0]["search_result"][0]
|
||||||
else:
|
else:
|
||||||
text = search_results[0]
|
text = search_results[0]
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue