fix: move access control check to infra

This commit is contained in:
Boris Arzentar 2025-11-07 16:25:42 +01:00
parent a081f8ba43
commit 2e234e1a87
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
12 changed files with 95 additions and 74 deletions

View file

@ -56,8 +56,8 @@ def upgrade() -> None:
sa.Column("dataset_id", sa.UUID, index=True), sa.Column("dataset_id", sa.UUID, index=True),
sa.Column("source_node_id", sa.UUID, nullable=False), sa.Column("source_node_id", sa.UUID, nullable=False),
sa.Column("destination_node_id", sa.UUID, nullable=False), sa.Column("destination_node_id", sa.UUID, nullable=False),
sa.Column("label", sa.String()), sa.Column("label", sa.Text()),
sa.Column("relationship_name", sa.String(), nullable=False), sa.Column("relationship_name", sa.Text(), nullable=False),
sa.Column("props", sa.JSON()), sa.Column("props", sa.JSON()),
sa.Column( sa.Column(
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)

View file

@ -66,7 +66,7 @@ class datasets:
if not dataset: if not dataset:
raise UnauthorizedDataAccessError(f"Dataset {dataset_id} not accessible.") raise UnauthorizedDataAccessError(f"Dataset {dataset_id} not accessible.")
await delete_dataset_nodes_and_edges(dataset_id) await delete_dataset_nodes_and_edges(dataset_id, user.id)
return await delete_dataset(dataset) return await delete_dataset(dataset)

View file

@ -3,12 +3,11 @@ from contextvars import ContextVar
from typing import Union from typing import Union
from uuid import UUID from uuid import UUID
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
is_backend_access_control_enabled,
)
from cognee.base_config import get_base_config from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.files.storage.config import file_storage_config from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.modules.users.methods import get_user
# Note: ContextVar allows us to use different graph db configurations in Cognee # Note: ContextVar allows us to use different graph db configurations in Cognee
# for different async tasks, threads and processes # for different async tasks, threads and processes
@ -16,40 +15,11 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None) graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None) session_user = ContextVar("session_user", default=None)
vector_dbs_with_multi_user_support = ["lancedb"]
graph_dbs_with_multi_user_support = ["kuzu"]
async def set_session_user_context_variable(user): async def set_session_user_context_variable(user):
session_user.set(user) session_user.set(user)
def multi_user_support_possible():
graph_db_config = get_graph_context_config()
vector_db_config = get_vectordb_context_config()
return (
graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support
and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support
)
def backend_access_control_enabled():
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
if backend_access_control is None:
# If backend access control is not defined in environment variables,
# enable it by default if graph and vector DBs can support it, otherwise disable it
return multi_user_support_possible()
elif backend_access_control.lower() == "true":
# If enabled, ensure that the current graph and vector DBs can support it
multi_user_support = multi_user_support_possible()
if not multi_user_support:
raise EnvironmentError(
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
)
return True
return False
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID): async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
""" """
If backend access control is enabled this function will ensure all datasets have their own databases, If backend access control is enabled this function will ensure all datasets have their own databases,
@ -71,11 +41,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
base_config = get_base_config() base_config = get_base_config()
if not backend_access_control_enabled(): if not is_backend_access_control_enabled():
return return
from cognee.modules.users.methods.get_user import get_user
user = await get_user(user_id) user = await get_user(user_id)
from cognee.infrastructure.databases.utils.get_or_create_dataset_database import (
get_or_create_dataset_database,
)
# To ensure permissions are enforced properly all datasets will have their own databases # To ensure permissions are enforced properly all datasets will have their own databases
dataset_database = await get_or_create_dataset_database(dataset, user) dataset_database = await get_or_create_dataset_database(dataset, user)

View file

@ -0,0 +1,34 @@
import os
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
from cognee.infrastructure.databases.graph.config import get_graph_context_config
# Vector database providers that multi_user_support_possible() accepts
# when it checks the "vector_db_provider" config value.
vector_dbs_with_multi_user_support = ["lancedb"]
# Graph database providers that multi_user_support_possible() accepts
# when it checks the "graph_database_provider" config value.
graph_dbs_with_multi_user_support = ["kuzu"]
def multi_user_support_possible():
    """
    Tell whether both configured databases are on the multi-user allow-lists.

    Reads the current graph and vector database context configs and compares
    their provider names against graph_dbs_with_multi_user_support and
    vector_dbs_with_multi_user_support.

    Returns:
    --------
        - bool: True only when the graph provider and the vector provider are
          both listed; False otherwise.
    """
    # Fetch both configs up front, in the same order as before.
    graph_config = get_graph_context_config()
    vector_config = get_vectordb_context_config()

    # Guard clause: an unsupported graph provider settles the answer without
    # looking at the vector provider at all.
    if graph_config["graph_database_provider"] not in graph_dbs_with_multi_user_support:
        return False

    return vector_config["vector_db_provider"] in vector_dbs_with_multi_user_support
def is_backend_access_control_enabled():
    """
    Decide whether backend access control is switched on.

    The decision is driven by the ENABLE_BACKEND_ACCESS_CONTROL environment
    variable:

        - not set: follow multi_user_support_possible(), i.e. enable access
          control by default only when the graph and vector DBs can support it.
        - "true" (any letter casing): enabled, provided the current graph and
          vector DBs can support it.
        - any other value: disabled.

    Returns:
    --------
        - bool: True when backend access control is enabled, False otherwise.

    Raises:
    -------
        - EnvironmentError: when the variable is set to "true" but
          multi_user_support_possible() reports no multi-user support.
    """
    env_setting = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)

    # Nothing configured explicitly: fall back to what the databases allow.
    if env_setting is None:
        return multi_user_support_possible()

    # Any explicit value other than "true" disables access control.
    if env_setting.lower() != "true":
        return False

    # Explicitly enabled: the configured databases must be able to honour it.
    if multi_user_support_possible():
        return True

    raise EnvironmentError(
        "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
    )

View file

@ -1,7 +1,7 @@
from uuid import UUID from uuid import UUID
from typing import Dict, List from typing import Dict, List
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger

View file

@ -3,7 +3,8 @@ from typing import Dict, List
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
from cognee.modules.engine.utils import generate_edge_id from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
from cognee.modules.graph.methods import ( from cognee.modules.graph.methods import (
delete_dataset_related_edges, delete_dataset_related_edges,
delete_dataset_related_nodes, delete_dataset_related_nodes,
@ -12,17 +13,26 @@ from cognee.modules.graph.methods import (
) )
async def delete_dataset_nodes_and_edges(dataset_id: UUID) -> None: async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> None:
affected_nodes = await get_dataset_related_nodes(dataset_id) affected_nodes = await get_dataset_related_nodes(dataset_id)
if len(affected_nodes) == 0: if len(affected_nodes) == 0:
return return
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
affected_relationships = await get_dataset_related_edges(dataset_id)
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
non_legacy_nodes = [
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
]
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
await graph_engine.delete_nodes([str(node.slug) for node in affected_nodes]) await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
affected_vector_collections: Dict[str, List] = {} affected_vector_collections: Dict[str, List] = {}
for node in affected_nodes: for node in non_legacy_nodes:
for indexed_field in node.indexed_fields: for indexed_field in node.indexed_fields:
collection_name = f"{node.type}_{indexed_field}" collection_name = f"{node.type}_{indexed_field}"
if collection_name not in affected_vector_collections: if collection_name not in affected_vector_collections:
@ -30,17 +40,22 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID) -> None:
affected_vector_collections[collection_name].append(node) affected_vector_collections[collection_name].append(node)
vector_engine = get_vector_engine() vector_engine = get_vector_engine()
for affected_collection, affected_nodes in affected_vector_collections.items(): for affected_collection, non_legacy_nodes in affected_vector_collections.items():
await vector_engine.delete_data_points( await vector_engine.delete_data_points(
affected_collection, [node.id for node in affected_nodes] affected_collection, [node.id for node in non_legacy_nodes]
) )
affected_relationships = await get_dataset_related_edges(dataset_id) if len(affected_relationships) > 0:
non_legacy_relationships = [
edge
for index, edge in enumerate(affected_relationships)
if not is_legacy_relationship[index]
]
await vector_engine.delete_data_points( await vector_engine.delete_data_points(
"EdgeType_relationship_name", "EdgeType_relationship_name",
[generate_edge_id(edge.relationship_name) for edge in affected_relationships], [str(relationship.slug) for relationship in non_legacy_relationships],
) )
await delete_dataset_related_nodes(dataset_id) await delete_dataset_related_nodes(dataset_id)
await delete_dataset_related_edges(dataset_id) await delete_dataset_related_edges(dataset_id)

View file

@ -5,9 +5,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert
from cognee.modules.engine.utils import generate_edge_id from cognee.modules.engine.utils import generate_edge_id
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models.Edge import Edge from cognee.modules.graph.models.Edge import Edge
from .set_current_user import set_current_user from cognee.infrastructure.databases.relational.with_async_session import with_async_session
@with_async_session @with_async_session
@ -25,10 +24,6 @@ async def upsert_edges(
----------- -----------
- edges (list): A list of edges to be added to the graph. - edges (list): A list of edges to be added to the graph.
""" """
if session.get_bind().dialect.name == "postgresql":
# Set the session-level RLS variable
await set_current_user(session, user_id)
edges_to_add = [] edges_to_add = []
for edge in edges: for edge in edges:

View file

@ -3,10 +3,9 @@ from uuid import NAMESPACE_OID, UUID, uuid5
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Node from cognee.modules.graph.models import Node
from .set_current_user import set_current_user from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.infrastructure.databases.relational.with_async_session import with_async_session
@with_async_session @with_async_session
@ -20,10 +19,6 @@ async def upsert_nodes(
----------- -----------
- nodes (list): A list of nodes to be added to the graph. - nodes (list): A list of nodes to be added to the graph.
""" """
if session.get_bind().dialect.name == "postgresql":
# Set the session-level RLS variable
await set_current_user(session, user_id)
upsert_statement = ( upsert_statement = (
insert(Node) insert(Node)
.values( .values(

View file

@ -2,9 +2,9 @@ from datetime import datetime, timezone
from sqlalchemy import ( from sqlalchemy import (
# event, # event,
DateTime, DateTime,
String,
JSON, JSON,
UUID, UUID,
Text,
) )
# from sqlalchemy.schema import DDL # from sqlalchemy.schema import DDL
@ -29,9 +29,9 @@ class Edge(Base):
source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
relationship_name: Mapped[str | None] = mapped_column(String(255), nullable=False) relationship_name: Mapped[str | None] = mapped_column(Text, nullable=False)
label: Mapped[str | None] = mapped_column(String(255)) label: Mapped[str | None] = mapped_column(Text)
attributes: Mapped[dict | None] = mapped_column(JSON) attributes: Mapped[dict | None] = mapped_column(JSON)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(

View file

@ -4,11 +4,13 @@ from uuid import UUID
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from typing import Any, List, Optional, Tuple, Type, Union from typing import Any, List, Optional, Tuple, Type, Union
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
is_backend_access_control_enabled,
)
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables from cognee.context_global_variables import set_database_global_context_variables
from cognee.context_global_variables import backend_access_control_enabled
from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -74,7 +76,7 @@ async def search(
) )
# Use search function filtered by permissions if access control is enabled # Use search function filtered by permissions if access control is enabled
if backend_access_control_enabled(): if is_backend_access_control_enabled():
search_results = await authorized_search( search_results = await authorized_search(
query_type=query_type, query_type=query_type,
query_text=query_text, query_text=query_text,
@ -156,7 +158,7 @@ async def search(
) )
else: else:
# This is for maintaining backwards compatibility # This is for maintaining backwards compatibility
if backend_access_control_enabled(): if is_backend_access_control_enabled():
return_value = [] return_value = []
for search_result in search_results: for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result) prepared_search_results = await prepare_search_result(search_result)

View file

@ -5,7 +5,9 @@ from ..models import User
from ..get_fastapi_users import get_fastapi_users from ..get_fastapi_users import get_fastapi_users
from .get_default_user import get_default_user from .get_default_user import get_default_user
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.context_global_variables import backend_access_control_enabled from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
is_backend_access_control_enabled,
)
logger = get_logger("get_authenticated_user") logger = get_logger("get_authenticated_user")
@ -13,7 +15,7 @@ logger = get_logger("get_authenticated_user")
# Check environment variable to determine authentication requirement # Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = ( REQUIRE_AUTHENTICATION = (
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true" os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
or backend_access_control_enabled() or is_backend_access_control_enabled()
) )
fastapi_users = get_fastapi_users() fastapi_users = get_fastapi_users()

View file

@ -1,7 +1,11 @@
import pathlib
import os import os
import cognee import cognee
import pathlib
from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
is_backend_access_control_enabled,
)
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -147,9 +151,7 @@ async def main():
f"{name}: expected single-element list, got {len(search_results)}" f"{name}: expected single-element list, got {len(search_results)}"
) )
from cognee.context_global_variables import backend_access_control_enabled if is_backend_access_control_enabled():
if backend_access_control_enabled():
text = search_results[0]["search_result"][0] text = search_results[0]["search_result"][0]
else: else:
text = search_results[0] text = search_results[0]