fix: move access control check to infra
This commit is contained in:
parent
a081f8ba43
commit
2e234e1a87
12 changed files with 95 additions and 74 deletions
|
|
@ -56,8 +56,8 @@ def upgrade() -> None:
|
|||
sa.Column("dataset_id", sa.UUID, index=True),
|
||||
sa.Column("source_node_id", sa.UUID, nullable=False),
|
||||
sa.Column("destination_node_id", sa.UUID, nullable=False),
|
||||
sa.Column("label", sa.String()),
|
||||
sa.Column("relationship_name", sa.String(), nullable=False),
|
||||
sa.Column("label", sa.Text()),
|
||||
sa.Column("relationship_name", sa.Text(), nullable=False),
|
||||
sa.Column("props", sa.JSON()),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class datasets:
|
|||
if not dataset:
|
||||
raise UnauthorizedDataAccessError(f"Dataset {dataset_id} not accessible.")
|
||||
|
||||
await delete_dataset_nodes_and_edges(dataset_id)
|
||||
await delete_dataset_nodes_and_edges(dataset_id, user.id)
|
||||
|
||||
return await delete_dataset(dataset)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,12 +3,11 @@ from contextvars import ContextVar
|
|||
from typing import Union
|
||||
from uuid import UUID
|
||||
|
||||
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
|
||||
is_backend_access_control_enabled,
|
||||
)
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||
from cognee.infrastructure.files.storage.config import file_storage_config
|
||||
from cognee.modules.users.methods import get_user
|
||||
|
||||
# Note: ContextVar allows us to use different graph db configurations in Cognee
|
||||
# for different async tasks, threads and processes
|
||||
|
|
@ -16,40 +15,11 @@ vector_db_config = ContextVar("vector_db_config", default=None)
|
|||
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||
session_user = ContextVar("session_user", default=None)
|
||||
|
||||
vector_dbs_with_multi_user_support = ["lancedb"]
|
||||
graph_dbs_with_multi_user_support = ["kuzu"]
|
||||
|
||||
|
||||
async def set_session_user_context_variable(user):
|
||||
session_user.set(user)
|
||||
|
||||
|
||||
def multi_user_support_possible():
|
||||
graph_db_config = get_graph_context_config()
|
||||
vector_db_config = get_vectordb_context_config()
|
||||
return (
|
||||
graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support
|
||||
and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support
|
||||
)
|
||||
|
||||
|
||||
def backend_access_control_enabled():
|
||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||
if backend_access_control is None:
|
||||
# If backend access control is not defined in environment variables,
|
||||
# enable it by default if graph and vector DBs can support it, otherwise disable it
|
||||
return multi_user_support_possible()
|
||||
elif backend_access_control.lower() == "true":
|
||||
# If enabled, ensure that the current graph and vector DBs can support it
|
||||
multi_user_support = multi_user_support_possible()
|
||||
if not multi_user_support:
|
||||
raise EnvironmentError(
|
||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||
"""
|
||||
If backend access control is enabled this function will ensure all datasets have their own databases,
|
||||
|
|
@ -71,11 +41,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
base_config = get_base_config()
|
||||
|
||||
if not backend_access_control_enabled():
|
||||
if not is_backend_access_control_enabled():
|
||||
return
|
||||
|
||||
from cognee.modules.users.methods.get_user import get_user
|
||||
|
||||
user = await get_user(user_id)
|
||||
|
||||
from cognee.infrastructure.databases.utils.get_or_create_dataset_database import (
|
||||
get_or_create_dataset_database,
|
||||
)
|
||||
|
||||
# To ensure permissions are enforced properly all datasets will have their own databases
|
||||
dataset_database = await get_or_create_dataset_database(dataset, user)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
import os
|
||||
|
||||
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
||||
|
||||
|
||||
vector_dbs_with_multi_user_support = ["lancedb"]
|
||||
graph_dbs_with_multi_user_support = ["kuzu"]
|
||||
|
||||
|
||||
def multi_user_support_possible():
|
||||
graph_db_config = get_graph_context_config()
|
||||
vector_db_config = get_vectordb_context_config()
|
||||
return (
|
||||
graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support
|
||||
and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support
|
||||
)
|
||||
|
||||
|
||||
def is_backend_access_control_enabled():
|
||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||
if backend_access_control is None:
|
||||
# If backend access control is not defined in environment variables,
|
||||
# enable it by default if graph and vector DBs can support it, otherwise disable it
|
||||
return multi_user_support_possible()
|
||||
elif backend_access_control.lower() == "true":
|
||||
# If enabled, ensure that the current graph and vector DBs can support it
|
||||
multi_user_support = multi_user_support_possible()
|
||||
if not multi_user_support:
|
||||
raise EnvironmentError(
|
||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
from uuid import UUID
|
||||
from typing import Dict, List
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
||||
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
|
||||
from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from typing import Dict, List
|
|||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine
|
||||
from cognee.modules.engine.utils import generate_edge_id
|
||||
from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger
|
||||
from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger
|
||||
from cognee.modules.graph.methods import (
|
||||
delete_dataset_related_edges,
|
||||
delete_dataset_related_nodes,
|
||||
|
|
@ -12,17 +13,26 @@ from cognee.modules.graph.methods import (
|
|||
)
|
||||
|
||||
|
||||
async def delete_dataset_nodes_and_edges(dataset_id: UUID) -> None:
|
||||
async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> None:
|
||||
affected_nodes = await get_dataset_related_nodes(dataset_id)
|
||||
|
||||
if len(affected_nodes) == 0:
|
||||
return
|
||||
|
||||
is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id)
|
||||
|
||||
affected_relationships = await get_dataset_related_edges(dataset_id)
|
||||
is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id)
|
||||
|
||||
non_legacy_nodes = [
|
||||
node for index, node in enumerate(affected_nodes) if not is_legacy_node[index]
|
||||
]
|
||||
|
||||
graph_engine = await get_graph_engine()
|
||||
await graph_engine.delete_nodes([str(node.slug) for node in affected_nodes])
|
||||
await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes])
|
||||
|
||||
affected_vector_collections: Dict[str, List] = {}
|
||||
for node in affected_nodes:
|
||||
for node in non_legacy_nodes:
|
||||
for indexed_field in node.indexed_fields:
|
||||
collection_name = f"{node.type}_{indexed_field}"
|
||||
if collection_name not in affected_vector_collections:
|
||||
|
|
@ -30,17 +40,22 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID) -> None:
|
|||
affected_vector_collections[collection_name].append(node)
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
for affected_collection, affected_nodes in affected_vector_collections.items():
|
||||
for affected_collection, non_legacy_nodes in affected_vector_collections.items():
|
||||
await vector_engine.delete_data_points(
|
||||
affected_collection, [node.id for node in affected_nodes]
|
||||
affected_collection, [node.id for node in non_legacy_nodes]
|
||||
)
|
||||
|
||||
affected_relationships = await get_dataset_related_edges(dataset_id)
|
||||
if len(affected_relationships) > 0:
|
||||
non_legacy_relationships = [
|
||||
edge
|
||||
for index, edge in enumerate(affected_relationships)
|
||||
if not is_legacy_relationship[index]
|
||||
]
|
||||
|
||||
await vector_engine.delete_data_points(
|
||||
"EdgeType_relationship_name",
|
||||
[generate_edge_id(edge.relationship_name) for edge in affected_relationships],
|
||||
)
|
||||
await vector_engine.delete_data_points(
|
||||
"EdgeType_relationship_name",
|
||||
[str(relationship.slug) for relationship in non_legacy_relationships],
|
||||
)
|
||||
|
||||
await delete_dataset_related_nodes(dataset_id)
|
||||
await delete_dataset_related_edges(dataset_id)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.dialects.postgresql import insert
|
||||
from cognee.modules.engine.utils import generate_edge_id
|
||||
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models.Edge import Edge
|
||||
from .set_current_user import set_current_user
|
||||
from cognee.infrastructure.databases.relational.with_async_session import with_async_session
|
||||
|
||||
|
||||
@with_async_session
|
||||
|
|
@ -25,10 +24,6 @@ async def upsert_edges(
|
|||
-----------
|
||||
- edges (list): A list of edges to be added to the graph.
|
||||
"""
|
||||
if session.get_bind().dialect.name == "postgresql":
|
||||
# Set the session-level RLS variable
|
||||
await set_current_user(session, user_id)
|
||||
|
||||
edges_to_add = []
|
||||
|
||||
for edge in edges:
|
||||
|
|
|
|||
|
|
@ -3,10 +3,9 @@ from uuid import NAMESPACE_OID, UUID, uuid5
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.infrastructure.databases.relational import with_async_session
|
||||
from cognee.modules.graph.models import Node
|
||||
from .set_current_user import set_current_user
|
||||
from cognee.infrastructure.engine.models.DataPoint import DataPoint
|
||||
from cognee.infrastructure.databases.relational.with_async_session import with_async_session
|
||||
|
||||
|
||||
@with_async_session
|
||||
|
|
@ -20,10 +19,6 @@ async def upsert_nodes(
|
|||
-----------
|
||||
- nodes (list): A list of nodes to be added to the graph.
|
||||
"""
|
||||
if session.get_bind().dialect.name == "postgresql":
|
||||
# Set the session-level RLS variable
|
||||
await set_current_user(session, user_id)
|
||||
|
||||
upsert_statement = (
|
||||
insert(Node)
|
||||
.values(
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ from datetime import datetime, timezone
|
|||
from sqlalchemy import (
|
||||
# event,
|
||||
DateTime,
|
||||
String,
|
||||
JSON,
|
||||
UUID,
|
||||
Text,
|
||||
)
|
||||
|
||||
# from sqlalchemy.schema import DDL
|
||||
|
|
@ -29,9 +29,9 @@ class Edge(Base):
|
|||
source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||
destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False)
|
||||
|
||||
relationship_name: Mapped[str | None] = mapped_column(String(255), nullable=False)
|
||||
relationship_name: Mapped[str | None] = mapped_column(Text, nullable=False)
|
||||
|
||||
label: Mapped[str | None] = mapped_column(String(255))
|
||||
label: Mapped[str | None] = mapped_column(Text)
|
||||
attributes: Mapped[dict | None] = mapped_column(JSON)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
|
|
|||
|
|
@ -4,11 +4,13 @@ from uuid import UUID
|
|||
from fastapi.encoders import jsonable_encoder
|
||||
from typing import Any, List, Optional, Tuple, Type, Union
|
||||
|
||||
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
|
||||
is_backend_access_control_enabled,
|
||||
)
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
|
@ -74,7 +76,7 @@ async def search(
|
|||
)
|
||||
|
||||
# Use search function filtered by permissions if access control is enabled
|
||||
if backend_access_control_enabled():
|
||||
if is_backend_access_control_enabled():
|
||||
search_results = await authorized_search(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
|
|
@ -156,7 +158,7 @@ async def search(
|
|||
)
|
||||
else:
|
||||
# This is for maintaining backwards compatibility
|
||||
if backend_access_control_enabled():
|
||||
if is_backend_access_control_enabled():
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@ from ..models import User
|
|||
from ..get_fastapi_users import get_fastapi_users
|
||||
from .get_default_user import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
|
||||
is_backend_access_control_enabled,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("get_authenticated_user")
|
||||
|
|
@ -13,7 +15,7 @@ logger = get_logger("get_authenticated_user")
|
|||
# Check environment variable to determine authentication requirement
|
||||
REQUIRE_AUTHENTICATION = (
|
||||
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
|
||||
or backend_access_control_enabled()
|
||||
or is_backend_access_control_enabled()
|
||||
)
|
||||
|
||||
fastapi_users = get_fastapi_users()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import pathlib
|
||||
import os
|
||||
import cognee
|
||||
import pathlib
|
||||
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
from cognee.infrastructure.environment.config.is_backend_access_control_enabled import (
|
||||
is_backend_access_control_enabled,
|
||||
)
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
from cognee.modules.graph.utils import resolve_edges_to_text
|
||||
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||
|
|
@ -147,9 +151,7 @@ async def main():
|
|||
f"{name}: expected single-element list, got {len(search_results)}"
|
||||
)
|
||||
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
if backend_access_control_enabled():
|
||||
if is_backend_access_control_enabled():
|
||||
text = search_results[0]["search_result"][0]
|
||||
else:
|
||||
text = search_results[0]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue