refactor: Enable multi user mode by default if graph and vector db providers support it (#1695)

<!-- .github/pull_request_template.md -->

## Description
Enable multi user mode by default for supported graph and vector DBs

## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [ ] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [x] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [ ] **I have tested my changes thoroughly before submitting this PR**
- [ ] **This PR contains minimal changes necessary to address the
issue/feature**
- [ ] My code follows the project's coding standards and style
guidelines
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] All new and existing tests pass
- [ ] I have searched existing PRs to ensure this change hasn't been
submitted already
- [ ] I have linked any relevant issues in the description
- [ ] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
Vasilije 2025-11-06 08:40:06 +01:00 committed by GitHub
commit 8cc55ac0b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 84 additions and 92 deletions

View file

@ -169,8 +169,9 @@ REQUIRE_AUTHENTICATION=False
# Vector: LanceDB # Vector: LanceDB
# Graph: KuzuDB # Graph: KuzuDB
# #
# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset # It enforces creation of databases per Cognee user + dataset. Does not work with some graph and database providers.
ENABLE_BACKEND_ACCESS_CONTROL=False # Disable this mode when using unsupported graph/vector databases.
ENABLE_BACKEND_ACCESS_CONTROL=True
################################################################################ ################################################################################
# ☁️ Cloud Sync Settings # ☁️ Cloud Sync Settings

View file

@ -84,6 +84,7 @@ jobs:
GRAPH_DATABASE_PROVIDER: 'neo4j' GRAPH_DATABASE_PROVIDER: 'neo4j'
VECTOR_DB_PROVIDER: 'lancedb' VECTOR_DB_PROVIDER: 'lancedb'
DB_PROVIDER: 'sqlite' DB_PROVIDER: 'sqlite'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }} GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
@ -135,6 +136,7 @@ jobs:
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }} EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu' GRAPH_DATABASE_PROVIDER: 'kuzu'
VECTOR_DB_PROVIDER: 'pgvector' VECTOR_DB_PROVIDER: 'pgvector'
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
DB_PROVIDER: 'postgres' DB_PROVIDER: 'postgres'
DB_NAME: 'cognee_db' DB_NAME: 'cognee_db'
DB_HOST: '127.0.0.1' DB_HOST: '127.0.0.1'
@ -197,6 +199,7 @@ jobs:
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }} GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }} GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }} GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
DB_NAME: cognee_db DB_NAME: cognee_db
DB_HOST: 127.0.0.1 DB_HOST: 127.0.0.1
DB_PORT: 5432 DB_PORT: 5432

View file

@ -4,6 +4,8 @@ from typing import Union
from uuid import UUID from uuid import UUID
from cognee.base_config import get_base_config from cognee.base_config import get_base_config
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
from cognee.infrastructure.databases.graph.config import get_graph_context_config
from cognee.infrastructure.databases.utils import get_or_create_dataset_database from cognee.infrastructure.databases.utils import get_or_create_dataset_database
from cognee.infrastructure.files.storage.config import file_storage_config from cognee.infrastructure.files.storage.config import file_storage_config
from cognee.modules.users.methods import get_user from cognee.modules.users.methods import get_user
@ -14,11 +16,40 @@ vector_db_config = ContextVar("vector_db_config", default=None)
graph_db_config = ContextVar("graph_db_config", default=None) graph_db_config = ContextVar("graph_db_config", default=None)
session_user = ContextVar("session_user", default=None) session_user = ContextVar("session_user", default=None)
vector_dbs_with_multi_user_support = ["lancedb"]
graph_dbs_with_multi_user_support = ["kuzu"]
async def set_session_user_context_variable(user): async def set_session_user_context_variable(user):
session_user.set(user) session_user.set(user)
def multi_user_support_possible():
    """Tell whether the configured graph and vector databases both allow multi-user mode.

    Returns:
        bool: True only when the configured graph provider is listed in
        ``graph_dbs_with_multi_user_support`` and the configured vector provider
        is listed in ``vector_dbs_with_multi_user_support``.
    """
    # Local names deliberately differ from the module-level ContextVars
    # (graph_db_config / vector_db_config) so they are not shadowed here.
    graph_settings = get_graph_context_config()
    vector_settings = get_vectordb_context_config()

    # The vector provider is only inspected once the graph provider qualifies.
    if graph_settings["graph_database_provider"] not in graph_dbs_with_multi_user_support:
        return False
    return vector_settings["vector_db_provider"] in vector_dbs_with_multi_user_support
def backend_access_control_enabled():
    """Decide whether backend access control (multi-user mode) is active.

    The decision is driven by the ``ENABLE_BACKEND_ACCESS_CONTROL`` environment
    variable:

    * unset: enabled exactly when ``multi_user_support_possible()`` is true;
    * ``"true"`` (case-insensitive, surrounding whitespace ignored): enabled,
      provided the configured databases support it;
    * any other value: disabled.

    Returns:
        bool: True when backend access control should be used, False otherwise.

    Raises:
        EnvironmentError: If the variable requests access control but
            ``multi_user_support_possible()`` reports it is not supported.
    """
    raw_setting = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL")

    if raw_setting is None:
        # If backend access control is not defined in environment variables,
        # enable it by default if graph and vector DBs can support it, otherwise disable it
        return multi_user_support_possible()

    # Strip before comparing: a stray space or carriage return picked up from an
    # env file or CI config must not silently turn an explicit "true" into "disabled".
    if raw_setting.strip().lower() != "true":
        return False

    # If enabled, ensure that the current graph and vector DBs can support it
    if not multi_user_support_possible():
        raise EnvironmentError(
            "ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
        )
    return True
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID): async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
""" """
If backend access control is enabled this function will ensure all datasets have their own databases, If backend access control is enabled this function will ensure all datasets have their own databases,
@ -40,7 +71,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
base_config = get_base_config() base_config = get_base_config()
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": if not backend_access_control_enabled():
return return
user = await get_user(user_id) user = await get_user(user_id)

View file

@ -1,4 +1,3 @@
import os
import json import json
import asyncio import asyncio
from uuid import UUID from uuid import UUID
@ -9,6 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry from cognee.shared.utils import send_telemetry
from cognee.context_global_variables import set_database_global_context_variables from cognee.context_global_variables import set_database_global_context_variables
from cognee.context_global_variables import backend_access_control_enabled
from cognee.modules.engine.models.node_set import NodeSet from cognee.modules.engine.models.node_set import NodeSet
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
@ -74,7 +74,7 @@ async def search(
) )
# Use search function filtered by permissions if access control is enabled # Use search function filtered by permissions if access control is enabled
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": if backend_access_control_enabled():
search_results = await authorized_search( search_results = await authorized_search(
query_type=query_type, query_type=query_type,
query_text=query_text, query_text=query_text,
@ -156,7 +156,7 @@ async def search(
) )
else: else:
# This is for maintaining backwards compatibility # This is for maintaining backwards compatibility
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true": if backend_access_control_enabled():
return_value = [] return_value = []
for search_result in search_results: for search_result in search_results:
prepared_search_results = await prepare_search_result(search_result) prepared_search_results = await prepare_search_result(search_result)

View file

@ -5,6 +5,7 @@ from ..models import User
from ..get_fastapi_users import get_fastapi_users from ..get_fastapi_users import get_fastapi_users
from .get_default_user import get_default_user from .get_default_user import get_default_user
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.context_global_variables import backend_access_control_enabled
logger = get_logger("get_authenticated_user") logger = get_logger("get_authenticated_user")
@ -12,7 +13,7 @@ logger = get_logger("get_authenticated_user")
# Check environment variable to determine authentication requirement # Check environment variable to determine authentication requirement
REQUIRE_AUTHENTICATION = ( REQUIRE_AUTHENTICATION = (
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true" os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true" or backend_access_control_enabled()
) )
fastapi_users = get_fastapi_users() fastapi_users = get_fastapi_users()

View file

@ -39,12 +39,12 @@ async def main():
answer = await cognee.search("Do programmers change light bulbs?") answer = await cognee.search("Do programmers change light bulbs?")
assert len(answer) != 0 assert len(answer) != 0
lowercase_answer = answer[0].lower() lowercase_answer = answer[0]["search_result"][0].lower()
assert ("no" in lowercase_answer) or ("none" in lowercase_answer) assert ("no" in lowercase_answer) or ("none" in lowercase_answer)
answer = await cognee.search("What colours are there in the presentation table?") answer = await cognee.search("What colours are there in the presentation table?")
assert len(answer) != 0 assert len(answer) != 0
lowercase_answer = answer[0].lower() lowercase_answer = answer[0]["search_result"][0].lower()
assert ( assert (
("red" in lowercase_answer) ("red" in lowercase_answer)
and ("blue" in lowercase_answer) and ("blue" in lowercase_answer)

View file

@ -133,7 +133,7 @@ async def main():
extraction_tasks=extraction_tasks, extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks, enrichment_tasks=enrichment_tasks,
data=[{}], data=[{}],
dataset="feedback_enrichment_test_memify", dataset=dataset_name,
) )
nodes_after, edges_after = await graph_engine.get_graph_data() nodes_after, edges_after = await graph_engine.get_graph_data()

View file

@ -90,15 +90,17 @@ async def main():
) )
search_results = await cognee.search( search_results = await cognee.search(
query_type=SearchType.GRAPH_COMPLETION, query_text="What information do you contain?" query_type=SearchType.GRAPH_COMPLETION,
query_text="What information do you contain?",
dataset_ids=[pipeline_run_obj.dataset_id],
) )
assert "Mark" in search_results[0], ( assert "Mark" in search_results[0]["search_result"][0], (
"Failed to update document, no mention of Mark in search results" "Failed to update document, no mention of Mark in search results"
) )
assert "Cindy" in search_results[0], ( assert "Cindy" in search_results[0]["search_result"][0], (
"Failed to update document, no mention of Cindy in search results" "Failed to update document, no mention of Cindy in search results"
) )
assert "Artificial intelligence" not in search_results[0], ( assert "Artificial intelligence" not in search_results[0]["search_result"][0], (
"Failed to update document, Artificial intelligence still mentioned in search results" "Failed to update document, Artificial intelligence still mentioned in search results"
) )

View file

@ -27,6 +27,9 @@ def normalize_node_name(node_name: str) -> str:
async def setup_test_db(): async def setup_test_db():
# Disable backend access control to migrate relational data
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)

View file

@ -146,7 +146,13 @@ async def main():
assert len(search_results) == 1, ( assert len(search_results) == 1, (
f"{name}: expected single-element list, got {len(search_results)}" f"{name}: expected single-element list, got {len(search_results)}"
) )
text = search_results[0]
from cognee.context_global_variables import backend_access_control_enabled
if backend_access_control_enabled():
text = search_results[0]["search_result"][0]
else:
text = search_results[0]
assert isinstance(text, str), f"{name}: element should be a string" assert isinstance(text, str), f"{name}: element should be a string"
assert text.strip(), f"{name}: string should not be empty" assert text.strip(), f"{name}: string should not be empty"
assert "netherlands" in text.lower(), ( assert "netherlands" in text.lower(), (

View file

@ -1,3 +1,4 @@
import os
import pytest import pytest
from unittest.mock import patch, AsyncMock, MagicMock from unittest.mock import patch, AsyncMock, MagicMock
from uuid import uuid4 from uuid import uuid4
@ -5,8 +6,6 @@ from fastapi.testclient import TestClient
from types import SimpleNamespace from types import SimpleNamespace
import importlib import importlib
from cognee.api.client import app
# Fixtures for reuse across test classes # Fixtures for reuse across test classes
@pytest.fixture @pytest.fixture
@ -32,6 +31,10 @@ def mock_authenticated_user():
) )
# To turn off authentication we need to set the environment variable before importing the module
# Also both require_authentication and backend access control must be false
os.environ["REQUIRE_AUTHENTICATION"] = "false"
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
@ -40,6 +43,8 @@ class TestConditionalAuthenticationEndpoints:
@pytest.fixture @pytest.fixture
def client(self): def client(self):
from cognee.api.client import app
"""Create a test client.""" """Create a test client."""
return TestClient(app) return TestClient(app)
@ -133,6 +138,8 @@ class TestConditionalAuthenticationBehavior:
@pytest.fixture @pytest.fixture
def client(self): def client(self):
from cognee.api.client import app
return TestClient(app) return TestClient(app)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -209,6 +216,8 @@ class TestConditionalAuthenticationErrorHandling:
@pytest.fixture @pytest.fixture
def client(self): def client(self):
from cognee.api.client import app
return TestClient(app) return TestClient(app)
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
@ -232,7 +241,7 @@ class TestConditionalAuthenticationErrorHandling:
# The exact error message may vary depending on the actual database connection # The exact error message may vary depending on the actual database connection
# The important thing is that we get a 500 error when user creation fails # The important thing is that we get a 500 error when user creation fails
def test_current_environment_configuration(self): def test_current_environment_configuration(self, client):
"""Test that current environment configuration is working properly.""" """Test that current environment configuration is working properly."""
# This tests the actual module state without trying to change it # This tests the actual module state without trying to change it
from cognee.modules.users.methods.get_authenticated_user import ( from cognee.modules.users.methods.get_authenticated_user import (

View file

@ -107,29 +107,10 @@ class TestConditionalAuthenticationIntegration:
# REQUIRE_AUTHENTICATION should be a boolean # REQUIRE_AUTHENTICATION should be a boolean
assert isinstance(REQUIRE_AUTHENTICATION, bool) assert isinstance(REQUIRE_AUTHENTICATION, bool)
# Currently should be False (optional authentication)
assert not REQUIRE_AUTHENTICATION
class TestConditionalAuthenticationEnvironmentVariables: class TestConditionalAuthenticationEnvironmentVariables:
"""Test environment variable handling.""" """Test environment variable handling."""
def test_require_authentication_default_false(self):
"""Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars."""
with patch.dict(os.environ, {}, clear=True):
# Remove module from cache to force fresh import
module_name = "cognee.modules.users.methods.get_authenticated_user"
if module_name in sys.modules:
del sys.modules[module_name]
# Import after patching environment - module will see empty environment
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
importlib.invalidate_caches()
assert not REQUIRE_AUTHENTICATION
def test_require_authentication_true(self): def test_require_authentication_true(self):
"""Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported.""" """Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported."""
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}): with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}):
@ -145,50 +126,6 @@ class TestConditionalAuthenticationEnvironmentVariables:
assert REQUIRE_AUTHENTICATION assert REQUIRE_AUTHENTICATION
def test_require_authentication_false_explicit(self):
"""Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported."""
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}):
# Remove module from cache to force fresh import
module_name = "cognee.modules.users.methods.get_authenticated_user"
if module_name in sys.modules:
del sys.modules[module_name]
# Import after patching environment - module will see REQUIRE_AUTHENTICATION=false
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
assert not REQUIRE_AUTHENTICATION
def test_require_authentication_case_insensitive(self):
"""Test that environment variable parsing is case insensitive when imported."""
test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"]
for case in test_cases:
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}):
# Remove module from cache to force fresh import
module_name = "cognee.modules.users.methods.get_authenticated_user"
if module_name in sys.modules:
del sys.modules[module_name]
# Import after patching environment
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
expected = case.lower() == "true"
assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}"
def test_current_require_authentication_value(self):
"""Test that the current REQUIRE_AUTHENTICATION module value is as expected."""
from cognee.modules.users.methods.get_authenticated_user import (
REQUIRE_AUTHENTICATION,
)
# The module-level variable should currently be False (set at import time)
assert isinstance(REQUIRE_AUTHENTICATION, bool)
assert not REQUIRE_AUTHENTICATION
class TestConditionalAuthenticationEdgeCases: class TestConditionalAuthenticationEdgeCases:
"""Test edge cases and error scenarios.""" """Test edge cases and error scenarios."""

View file

@ -168,7 +168,7 @@ async def run_procurement_example():
for q in questions: for q in questions:
print(f"Question: \n{q}") print(f"Question: \n{q}")
results = await procurement_system.search_memory(q, search_categories=[category]) results = await procurement_system.search_memory(q, search_categories=[category])
top_answer = results[category][0] top_answer = results[category][0]["search_result"][0]
print(f"Answer: \n{top_answer.strip()}\n") print(f"Answer: \n{top_answer.strip()}\n")
research_notes[category].append({"question": q, "answer": top_answer}) research_notes[category].append({"question": q, "answer": top_answer})

View file

@ -1,5 +1,7 @@
import argparse import argparse
import asyncio import asyncio
import os
import cognee import cognee
from cognee import SearchType from cognee import SearchType
from cognee.shared.logging_utils import setup_logging, ERROR from cognee.shared.logging_utils import setup_logging, ERROR
@ -8,6 +10,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
async def main(repo_path, include_docs): async def main(repo_path, include_docs):
# Disable permissions feature for this example
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
run_status = False run_status = False
async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs): async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs):
run_status = run_status run_status = run_status

View file

@ -67,7 +67,6 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
extraction_tasks=extraction_tasks, extraction_tasks=extraction_tasks,
enrichment_tasks=enrichment_tasks, enrichment_tasks=enrichment_tasks,
data=[{}], # A placeholder to prevent fetching the entire graph data=[{}], # A placeholder to prevent fetching the entire graph
dataset="feedback_enrichment_minimal",
) )

View file

@ -89,7 +89,7 @@ async def main():
) )
print("Coding rules created by memify:") print("Coding rules created by memify:")
for coding_rule in coding_rules: for coding_rule in coding_rules[0]["search_result"][0]:
print("- " + coding_rule) print("- " + coding_rule)
# Visualize new graph with added memify context # Visualize new graph with added memify context

View file

@ -31,6 +31,9 @@ from cognee.infrastructure.databases.vector.pgvector import (
async def main(): async def main():
# Disable backend access control to migrate relational data
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
# Clean all data stored in Cognee # Clean all data stored in Cognee
await cognee.prune.prune_data() await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True) await cognee.prune.prune_system(metadata=True)

View file

@ -59,14 +59,6 @@ async def main():
for result_text in search_results: for result_text in search_results:
print(result_text) print(result_text)
# Example output:
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
# (...)
# It represents nodes and relationships in the knowledge graph:
# - The first element is the source node (e.g., 'natural language processing').
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
# - The third element is the target node (e.g., 'computer science').
if __name__ == "__main__": if __name__ == "__main__":
logger = setup_logging(log_level=ERROR) logger = setup_logging(log_level=ERROR)