diff --git a/alembic/versions/84e5d08260d6_replace_graph_ledger_table_with_nodes_.py b/alembic/versions/84e5d08260d6_replace_graph_ledger_table_with_nodes_.py index 44ef29d59..919b5bfb4 100644 --- a/alembic/versions/84e5d08260d6_replace_graph_ledger_table_with_nodes_.py +++ b/alembic/versions/84e5d08260d6_replace_graph_ledger_table_with_nodes_.py @@ -56,8 +56,8 @@ def upgrade() -> None: sa.Column("dataset_id", sa.UUID, index=True), sa.Column("source_node_id", sa.UUID, nullable=False), sa.Column("destination_node_id", sa.UUID, nullable=False), - sa.Column("label", sa.String()), - sa.Column("relationship_name", sa.String(), nullable=False), + sa.Column("label", sa.Text()), + sa.Column("relationship_name", sa.Text(), nullable=False), sa.Column("props", sa.JSON()), sa.Column( "created_at", sa.DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) diff --git a/cognee/api/v1/datasets/datasets.py b/cognee/api/v1/datasets/datasets.py index 5a6293240..0f2b5dcd5 100644 --- a/cognee/api/v1/datasets/datasets.py +++ b/cognee/api/v1/datasets/datasets.py @@ -66,7 +66,7 @@ class datasets: if not dataset: raise UnauthorizedDataAccessError(f"Dataset {dataset_id} not accessible.") - await delete_dataset_nodes_and_edges(dataset_id) + await delete_dataset_nodes_and_edges(dataset_id, user.id) return await delete_dataset(dataset) diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index f17c9187a..6e2acf400 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -3,12 +3,11 @@ from contextvars import ContextVar from typing import Union from uuid import UUID +from cognee.infrastructure.environment.config.is_backend_access_control_enabled import ( + is_backend_access_control_enabled, +) from cognee.base_config import get_base_config -from cognee.infrastructure.databases.vector.config import get_vectordb_context_config -from cognee.infrastructure.databases.graph.config import get_graph_context_config -from cognee.infrastructure.databases.utils import get_or_create_dataset_database from cognee.infrastructure.files.storage.config import file_storage_config -from cognee.modules.users.methods import get_user # Note: ContextVar allows us to use different graph db configurations in Cognee # for different async tasks, threads and processes @@ -16,40 +15,11 @@ vector_db_config = ContextVar("vector_db_config", default=None) graph_db_config = ContextVar("graph_db_config", default=None) session_user = ContextVar("session_user", default=None) -vector_dbs_with_multi_user_support = ["lancedb"] -graph_dbs_with_multi_user_support = ["kuzu"] - async def set_session_user_context_variable(user): session_user.set(user) -def multi_user_support_possible(): - graph_db_config = get_graph_context_config() - vector_db_config = get_vectordb_context_config() - return ( - graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support - and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support - ) - - -def backend_access_control_enabled(): - backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None) - if backend_access_control is None: - # If backend access control is not defined in environment variables, - # enable it by default if graph and vector DBs can support it, otherwise disable it - return multi_user_support_possible() - elif backend_access_control.lower() == "true": - # If enabled, ensure that the current graph and vector DBs can support it - multi_user_support = multi_user_support_possible() - if not multi_user_support: - raise EnvironmentError( - "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control." - ) - return True - return False - - async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID): """ If backend access control is enabled this function will ensure all datasets have their own databases, @@ -71,11 +41,17 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ base_config = get_base_config() - if not backend_access_control_enabled(): + if not is_backend_access_control_enabled(): return + from cognee.modules.users.methods.get_user import get_user + user = await get_user(user_id) + from cognee.infrastructure.databases.utils.get_or_create_dataset_database import ( + get_or_create_dataset_database, + ) + # To ensure permissions are enforced properly all datasets will have their own databases dataset_database = await get_or_create_dataset_database(dataset, user) diff --git a/cognee/infrastructure/environment/config/is_backend_access_control_enabled.py b/cognee/infrastructure/environment/config/is_backend_access_control_enabled.py new file mode 100644 index 000000000..0aed55119 --- /dev/null +++ b/cognee/infrastructure/environment/config/is_backend_access_control_enabled.py @@ -0,0 +1,34 @@ +import os + +from cognee.infrastructure.databases.vector.config import get_vectordb_context_config +from cognee.infrastructure.databases.graph.config import get_graph_context_config + + +vector_dbs_with_multi_user_support = ["lancedb"] +graph_dbs_with_multi_user_support = ["kuzu"] + + +def multi_user_support_possible(): + graph_db_config = get_graph_context_config() + vector_db_config = get_vectordb_context_config() + return ( + graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support + and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support + ) + + +def is_backend_access_control_enabled(): + backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None) + if backend_access_control is None: + # If backend access control is not defined in environment variables, + # enable it by default if graph and vector DBs can support it, otherwise disable it + return multi_user_support_possible() + elif backend_access_control.lower() == "true": + # If enabled, ensure that the current graph and vector DBs can support it + multi_user_support = multi_user_support_possible() + if not multi_user_support: + raise EnvironmentError( + "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control." + ) + return True + return False diff --git a/cognee/modules/graph/methods/delete_data_nodes_and_edges.py b/cognee/modules/graph/methods/delete_data_nodes_and_edges.py index d96d8f265..1ce56a313 100644 --- a/cognee/modules/graph/methods/delete_data_nodes_and_edges.py +++ b/cognee/modules/graph/methods/delete_data_nodes_and_edges.py @@ -1,7 +1,7 @@ from uuid import UUID from typing import Dict, List -from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.databases.graph.get_graph_engine import get_graph_engine from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger diff --git a/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py b/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py index b11db3d56..391816a13 100644 --- a/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py +++ b/cognee/modules/graph/methods/delete_dataset_nodes_and_edges.py @@ -3,7 +3,8 @@ from typing import Dict, List from cognee.infrastructure.databases.graph import get_graph_engine from cognee.infrastructure.databases.vector.get_vector_engine import get_vector_engine -from cognee.modules.engine.utils import generate_edge_id +from cognee.modules.graph.legacy.has_nodes_in_legacy_ledger import has_nodes_in_legacy_ledger +from cognee.modules.graph.legacy.has_edges_in_legacy_ledger import has_edges_in_legacy_ledger from cognee.modules.graph.methods import ( delete_dataset_related_edges, delete_dataset_related_nodes, @@ -12,17 +13,26 @@ from cognee.modules.graph.methods import ( ) -async def delete_dataset_nodes_and_edges(dataset_id: UUID) -> None: +async def delete_dataset_nodes_and_edges(dataset_id: UUID, user_id: UUID) -> None: affected_nodes = await get_dataset_related_nodes(dataset_id) if len(affected_nodes) == 0: return + is_legacy_node = await has_nodes_in_legacy_ledger(affected_nodes, user_id) + + affected_relationships = await get_dataset_related_edges(dataset_id) + is_legacy_relationship = await has_edges_in_legacy_ledger(affected_relationships, user_id) + + non_legacy_nodes = [ + node for index, node in enumerate(affected_nodes) if not is_legacy_node[index] + ] + graph_engine = await get_graph_engine() - await graph_engine.delete_nodes([str(node.slug) for node in affected_nodes]) + await graph_engine.delete_nodes([str(node.slug) for node in non_legacy_nodes]) affected_vector_collections: Dict[str, List] = {} - for node in affected_nodes: + for node in non_legacy_nodes: for indexed_field in node.indexed_fields: collection_name = f"{node.type}_{indexed_field}" if collection_name not in affected_vector_collections: @@ -30,17 +40,22 @@ async def delete_dataset_nodes_and_edges(dataset_id: UUID) -> None: affected_vector_collections[collection_name].append(node) vector_engine = get_vector_engine() - for affected_collection, affected_nodes in affected_vector_collections.items(): + for affected_collection, non_legacy_nodes in affected_vector_collections.items(): await vector_engine.delete_data_points( - affected_collection, [node.id for node in affected_nodes] + affected_collection, [node.id for node in non_legacy_nodes] ) - affected_relationships = await get_dataset_related_edges(dataset_id) + if len(affected_relationships) > 0: + non_legacy_relationships = [ + edge + for index, edge in enumerate(affected_relationships) + if not is_legacy_relationship[index] + ] - await vector_engine.delete_data_points( - "EdgeType_relationship_name", - [generate_edge_id(edge.relationship_name) for edge in affected_relationships], - ) + await vector_engine.delete_data_points( + "EdgeType_relationship_name", + [str(relationship.slug) for relationship in non_legacy_relationships], + ) await delete_dataset_related_nodes(dataset_id) await delete_dataset_related_edges(dataset_id) diff --git a/cognee/modules/graph/methods/upsert_edges.py b/cognee/modules/graph/methods/upsert_edges.py index d17367f44..0029b8b24 100644 --- a/cognee/modules/graph/methods/upsert_edges.py +++ b/cognee/modules/graph/methods/upsert_edges.py @@ -5,9 +5,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.dialects.postgresql import insert from cognee.modules.engine.utils import generate_edge_id -from cognee.infrastructure.databases.relational import with_async_session from cognee.modules.graph.models.Edge import Edge -from .set_current_user import set_current_user +from cognee.infrastructure.databases.relational.with_async_session import with_async_session @with_async_session @@ -25,10 +24,6 @@ async def upsert_edges( ----------- - edges (list): A list of edges to be added to the graph. """ - if session.get_bind().dialect.name == "postgresql": - # Set the session-level RLS variable - await set_current_user(session, user_id) - edges_to_add = [] for edge in edges: diff --git a/cognee/modules/graph/methods/upsert_nodes.py b/cognee/modules/graph/methods/upsert_nodes.py index eeb159c84..181171d48 100644 --- a/cognee/modules/graph/methods/upsert_nodes.py +++ b/cognee/modules/graph/methods/upsert_nodes.py @@ -3,10 +3,9 @@ from uuid import NAMESPACE_OID, UUID, uuid5 from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.dialects.postgresql import insert -from cognee.infrastructure.engine.models.DataPoint import DataPoint -from cognee.infrastructure.databases.relational import with_async_session from cognee.modules.graph.models import Node -from .set_current_user import set_current_user +from cognee.infrastructure.engine.models.DataPoint import DataPoint +from cognee.infrastructure.databases.relational.with_async_session import with_async_session @with_async_session @@ -20,10 +19,6 @@ async def upsert_nodes( ----------- - nodes (list): A list of nodes to be added to the graph. """ - if session.get_bind().dialect.name == "postgresql": - # Set the session-level RLS variable - await set_current_user(session, user_id) - upsert_statement = ( insert(Node) .values( diff --git a/cognee/modules/graph/models/Edge.py b/cognee/modules/graph/models/Edge.py index c04b205e6..1ed992cab 100644 --- a/cognee/modules/graph/models/Edge.py +++ b/cognee/modules/graph/models/Edge.py @@ -2,9 +2,9 @@ from datetime import datetime, timezone from sqlalchemy import ( # event, DateTime, - String, JSON, UUID, + Text, ) # from sqlalchemy.schema import DDL @@ -29,9 +29,9 @@ class Edge(Base): source_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) destination_node_id: Mapped[UUID] = mapped_column(UUID(as_uuid=True), nullable=False) - relationship_name: Mapped[str | None] = mapped_column(String(255), nullable=False) + relationship_name: Mapped[str | None] = mapped_column(Text, nullable=False) - label: Mapped[str | None] = mapped_column(String(255)) + label: Mapped[str | None] = mapped_column(Text) attributes: Mapped[dict | None] = mapped_column(JSON) created_at: Mapped[datetime] = mapped_column( diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index 5e465b239..d9256ae87 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -4,11 +4,13 @@ from uuid import UUID from fastapi.encoders import jsonable_encoder from typing import Any, List, Optional, Tuple, Type, Union +from cognee.infrastructure.environment.config.is_backend_access_control_enabled import ( + is_backend_access_control_enabled, +) from cognee.infrastructure.databases.graph import get_graph_engine from cognee.shared.logging_utils import get_logger from cognee.shared.utils import send_telemetry from cognee.context_global_variables import set_database_global_context_variables -from cognee.context_global_variables import backend_access_control_enabled from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge @@ -74,7 +76,7 @@ async def search( ) # Use search function filtered by permissions if access control is enabled - if backend_access_control_enabled(): + if is_backend_access_control_enabled(): search_results = await authorized_search( query_type=query_type, query_text=query_text, @@ -156,7 +158,7 @@ async def search( ) else: # This is for maintaining backwards compatibility - if backend_access_control_enabled(): + if is_backend_access_control_enabled(): return_value = [] for search_result in search_results: prepared_search_results = await prepare_search_result(search_result) diff --git a/cognee/modules/users/methods/get_authenticated_user.py b/cognee/modules/users/methods/get_authenticated_user.py index d6d701737..90c4a1274 100644 --- a/cognee/modules/users/methods/get_authenticated_user.py +++ b/cognee/modules/users/methods/get_authenticated_user.py @@ -5,7 +5,9 @@ from ..models import User from ..get_fastapi_users import get_fastapi_users from .get_default_user import get_default_user from cognee.shared.logging_utils import get_logger -from cognee.context_global_variables import backend_access_control_enabled +from cognee.infrastructure.environment.config.is_backend_access_control_enabled import ( + is_backend_access_control_enabled, +) logger = get_logger("get_authenticated_user") @@ -13,7 +15,7 @@ logger = get_logger("get_authenticated_user") # Check environment variable to determine authentication requirement REQUIRE_AUTHENTICATION = ( os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true" - or backend_access_control_enabled() + or is_backend_access_control_enabled() ) fastapi_users = get_fastapi_users() diff --git a/cognee/tests/test_search_db.py b/cognee/tests/test_search_db.py index bd11dc62e..57e86f6f7 100644 --- a/cognee/tests/test_search_db.py +++ b/cognee/tests/test_search_db.py @@ -1,7 +1,11 @@ -import pathlib import os import cognee +import pathlib + from cognee.infrastructure.databases.graph import get_graph_engine +from cognee.infrastructure.environment.config.is_backend_access_control_enabled import ( + is_backend_access_control_enabled, +) from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.utils import resolve_edges_to_text from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever @@ -147,9 +151,7 @@ async def main(): f"{name}: expected single-element list, got {len(search_results)}" ) - from cognee.context_global_variables import backend_access_control_enabled - - if backend_access_control_enabled(): + if is_backend_access_control_enabled(): text = search_results[0]["search_result"][0] else: text = search_results[0]