fix: move access control check to infra

This commit is contained in:
Boris Arzentar 2025-11-07 16:25:42 +01:00
parent a081f8ba43
commit 2e234e1a87
No known key found for this signature in database
GPG key ID: D5CC274C784807B7
12 changed files with 95 additions and 74 deletions

View file

@ -56,8 +56,8 @@ def upgrade() -> None:
sa.Column("dataset_id", sa.UUID, index=True),
sa.Column("source_node_id", sa.UUID, nullable=False),
sa.Column("destination_node_id", sa.UUID, nullable=False),
sa.Column("label", sa.String()),
sa.Column("relationship_name", sa.String(), nullable=False),
sa.Column("label", sa.Text()),
sa.Column("relationship_name", sa.Text(), nullable=False),
sa.Column("props", sa.JSON()),
sa.Column(
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)

View file

@ -66,7 +66,7 @@ class datasets:
if not dataset:
raise UnauthorizedDataAccessError(f"Dataset {dataset_id} not accessible.")
await delete_dataset_nodes_and_edges(dataset_id)
await delete_dataset_nodes_and_edges(dataset_id, user.id)
return await delete_dataset(dataset)

View file

@ -3,12 +3,11 @@ from contextvars import ContextVar
from typing import Union
from uuid import UUID
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
is_backend_access_control_enabled,
)
from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.modules.users.methods import get_user
# Note: ContextVar allows us to use different graph db configurations in Cognee
# for different async tasks, threads and processes
@ -16,40 +15,11 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None)
vector_dbs_with_multi_user_support = ["lancedb"]
graph_dbs_with_multi_user_support = ["kuzu"]
async def set_session_user_context_variable(user):
    """Record ``user`` as the value of the ``session_user`` context variable.

    Args:
        user: The object to store; it is passed to ``session_user.set`` unchanged.

    Returns:
        None.
    """
    session_user.set(user)
def multi_user_support_possible():
    """Tell whether the configured graph and vector providers are both allow-listed.

    Returns:
        bool: True only when the ``graph_database_provider`` entry of the graph
        config is in ``graph_dbs_with_multi_user_support`` and the
        ``vector_db_provider`` entry of the vector config is in
        ``vector_dbs_with_multi_user_support``.
    """
    # Both configs are fetched up front, before either provider is inspected.
    graph_settings = get_graph_context_config()
    vector_settings = get_vectordb_context_config()

    if graph_settings["graph_database_provider"] not in graph_dbs_with_multi_user_support:
        return False

    return vector_settings["vector_db_provider"] in vector_dbs_with_multi_user_support
def backend_access_control_enabled():
    """Resolve whether backend access control is active.

    The decision is driven by the ``ENABLE_BACKEND_ACCESS_CONTROL`` environment
    variable:

    - unset: the result of ``multi_user_support_possible()`` is returned, so
      access control defaults to on only when the databases can support it;
    - ``"true"`` (compared case-insensitively): returns True, provided
      ``multi_user_support_possible()`` is truthy;
    - any other value: returns False.

    Returns:
        bool: Whether backend access control should be treated as enabled.

    Raises:
        EnvironmentError: The variable is ``"true"`` but
            ``multi_user_support_possible()`` reports no multi-user support.
    """
    requested = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)

    # Nothing configured: follow whatever the current databases allow.
    if requested is None:
        return multi_user_support_possible()

    # Anything other than an explicit "true" switches the feature off.
    if requested.lower() != "true":
        return False

    # Explicitly requested: the databases must actually be able to honour it.
    if multi_user_support_possible():
        return True

    raise EnvironmentError(
        "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
    )
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
"""
If backend access control is enabled this function will ensure all datasets have their own databases,
@ -71,11 +41,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
base_config = get_base_config()
if not backend_access_control_enabled():
if not is_backend_access_control_enabled():
return
from cognee.modules.users.methods.get_user import get_user
user = await get_user(user_id)
from cognee.infrastructure.databases.utils.get_or_create_dataset_database import (
get_or_create_dataset_database,
)
# To ensure permissions are enforced properly all datasets will have their own databases
dataset_database = await get_or_create_dataset_database(dataset, user)

View file

@ -0,0 +1,34 @@
import os
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
from cognee.infrastructure.databases.graph.config import get_graph_context_config
vector_dbs_with_multi_user_support = ["lancedb"]
graph_dbs_with_multi_user_support = ["kuzu"]
def multi_user_support_possible():
    """Check that both configured database providers support multi-user mode.

    Returns:
        bool: True when the graph config's ``graph_database_provider`` is listed
        in ``graph_dbs_with_multi_user_support`` and the vector config's
        ``vector_db_provider`` is listed in ``vector_dbs_with_multi_user_support``;
        False otherwise.
    """
    # Load both configurations first; the provider checks follow.
    graph_config = get_graph_context_config()
    vector_config = get_vectordb_context_config()

    graph_provider_supported = (
        graph_config["graph_database_provider"] in graph_dbs_with_multi_user_support
    )
    if not graph_provider_supported:
        return False

    return vector_config["vector_db_provider"] in vector_dbs_with_multi_user_support
def is_backend_access_control_enabled():
    """Decide whether backend access control is switched on.

    Reads the ``ENABLE_BACKEND_ACCESS_CONTROL`` environment variable:

    - when it is not set, the answer is whatever ``multi_user_support_possible()``
      returns (enabled by default only if the databases can support it);
    - when it equals ``"true"`` ignoring case, the answer is True as long as
      ``multi_user_support_possible()`` is truthy;
    - for every other value the answer is False.

    Returns:
        bool: True if backend access control is enabled, False otherwise.

    Raises:
        EnvironmentError: The variable is set to ``"true"`` while
            ``multi_user_support_possible()`` reports that the current graph
            and/or vector databases lack multi-user support.
    """
    env_value = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)

    # Not configured: enable only if the graph and vector DBs can support it.
    if env_value is None:
        return multi_user_support_possible()

    # Configured, but not as "true": treat as disabled.
    if env_value.lower() != "true":
        return False

    # Explicitly enabled: refuse to continue on unsupported databases.
    if not multi_user_support_possible():
        raise EnvironmentError(
            "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
        )

    return True

View file

@ -1,7 +1,7 @@
from uuid import UUID
from typing import Dict, List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger

View file

@ -3,7 +3,8 @@ from typing import Dict, List
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
from cognee.modules.engine.utils import generate_edge_id
from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
from cognee.modules.graph.methods import (
delete_dataset_related_edges,
delete_dataset_related_nodes,
@ -12,17 +13,26 @@ from cognee.modules.graph.methods import (
)
async def delete_dataset_nodes_and_edges(dataset_id: UUID) -> None:
async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> None:
affected_nodes = await get_dataset_related_nodes(dataset_id)
if len(affected_nodes) == 0:
return
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
affected_relationships = await get_dataset_related_edges(dataset_id)
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
non_legacy_nodes = [
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
]
graph_engine = await get_graph_engine()
await graph_engine.delete_nodes([str(node.slug) for node in affected_nodes])
await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
affected_vector_collections: Dict[str, List] = {}
for node in affected_nodes:
for node in non_legacy_nodes:
for indexed_field in node.indexed_fields:
collection_name = f"{node.type}_{indexed_field}"
if collection_name not in affected_vector_collections:
@ -30,17 +40,22 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID) -> None:
affected_vector_collections[collection_name].append(node)
vector_engine = get_vector_engine()
for affected_collection, affected_nodes in affected_vector_collections.items():
for affected_collection, non_legacy_nodes in affected_vector_collections.items():
await vector_engine.delete_data_points(
affected_collection, [node.id for node in affected_nodes]
affected_collection, [node.id for node in non_legacy_nodes]
)
affected_relationships = await get_dataset_related_edges(dataset_id)
if len(affected_relationships) > 0:
non_legacy_relationships = [
edge
for index, edge in enumerate(affected_relationships)
if not is_legacy_relationship[index]
]
await vector_engine.delete_data_points(
"EdgeType_relationship_name",
[generate_edge_id(edge.relationship_name) for edge in affected_relationships],
)
await vector_engine.delete_data_points(
"EdgeType_relationship_name",
[str(relationship.slug) for relationship in non_legacy_relationships],
)
await delete_dataset_related_nodes(dataset_id)
await delete_dataset_related_edges(dataset_id)

View file

@ -5,9 +5,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.dialects.postgresql import insert
from cognee.modules.engine.utils import generate_edge_id
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models.Edge import Edge
from .set_current_user import set_current_user
from cognee.infrastructure.databases.relational.with_async_session import with_async_session
@with_async_session
@ -25,10 +24,6 @@ async def upsert_edges(
-----------
- edges (list): A list of edges to be added to the graph.
"""
if session.get_bind().dialect.name == "postgresql":
# Set the session-level RLS variable
await set_current_user(session, user_id)
edges_to_add = []
for edge in edges:

View file

@ -3,10 +3,9 @@ from uuid import NAMESPACE_OID, UUID, uuid5
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.dialects.postgresql import insert
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.infrastructure.databases.relational import with_async_session
from cognee.modules.graph.models import Node
from .set_current_user import set_current_user
from cognee.infrastructure.engine.models.DataPoint import DataPoint
from cognee.infrastructure.databases.relational.with_async_session import with_async_session
@with_async_session
@ -20,10 +19,6 @@ async def upsert_nodes(
-----------
- nodes (list): A list of nodes to be added to the graph.
"""
if session.get_bind().dialect.name == "postgresql":
# Set the session-level RLS variable
await set_current_user(session, user_id)
upsert_statement = (
insert(Node)
.values(

View file

@ -2,9 +2,9 @@ from datetime import datetime, timezone
from sqlalchemy import (
# event,
DateTime,
String,
JSON,
UUID,
Text,
)
# from sqlalchemy.schema import DDL
@ -29,9 +29,9 @@ class Edge(Base):
source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
relationship_name: Mapped[str | None] = mapped_column(String(255), nullable=False)
relationship_name: Mapped[str | None] = mapped_column(Text, nullable=False)
label: Mapped[str | None] = mapped_column(String(255))
label: Mapped[str | None] = mapped_column(Text)
attributes: Mapped[dict | None] = mapped_column(JSON)
created_at: Mapped[datetime] = mapped_column(

View file

@ -4,11 +4,13 @@ from uuid import UUID
from fastapi.encoders import jsonable_encoder
from typing import Any, List, Optional, Tuple, Type, Union
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
is_backend_access_control_enabled,
)
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables
from cognee.context_global_variables import backend_access_control_enabled
from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -74,7 +76,7 @@ async def search(
)
# Use search function filtered by permissions if access control is enabled
if backend_access_control_enabled():
if is_backend_access_control_enabled():
search_results = await authorized_search(
query_type=query_type,
query_text=query_text,
@ -156,7 +158,7 @@ async def search(
)
else:
# This is for maintaining backwards compatibility
if backend_access_control_enabled():
if is_backend_access_control_enabled():
return_value = []
for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result)

View file

@ -5,7 +5,9 @@ from ..models import User
from ..get_fastapi_users import get_fastapi_users
from .get_default_user import get_default_user
from cognee.shared.logging_utils import get_logger
from cognee.context_global_variables import backend_access_control_enabled
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
is_backend_access_control_enabled,
)
logger = get_logger("get_authenticated_user")
@ -13,7 +15,7 @@ logger = get_logger("get_authenticated_user")
# Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = (
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
or backend_access_control_enabled()
or is_backend_access_control_enabled()
)
fastapi_users = get_fastapi_users()

View file

@ -1,7 +1,11 @@
import pathlib
import os
import cognee
import pathlib
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
is_backend_access_control_enabled,
)
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
from cognee.modules.graph.utils import resolve_edges_to_text
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
@ -147,9 +151,7 @@ async def main():
f"{name}: expected single-element list, got {len(search_results)}"
)
from cognee.context_global_variables import backend_access_control_enabled
if backend_access_control_enabled():
if is_backend_access_control_enabled():
text = search_results[0]["search_result"][0]
else:
text = search_results[0]