refactor: Enable multi user mode by default if graph and vector db pr… (#1695)
…oviders support it <!-- .github/pull_request_template.md --> ## Description Enable multi user mode by default for supported graph and vector DBs ## Type of Change <!-- Please check the relevant option --> - [ ] Bug fix (non-breaking change that fixes an issue) - [ ] New feature (non-breaking change that adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Code refactoring - [ ] Performance improvement - [ ] Other (please specify): ## Screenshots/Videos (if applicable) <!-- Add screenshots or videos to help explain your changes --> ## Pre-submission Checklist <!-- Please check all boxes that apply before submitting your PR --> - [ ] **I have tested my changes thoroughly before submitting this PR** - [ ] **This PR contains minimal changes necessary to address the issue/feature** - [ ] My code follows the project's coding standards and style guidelines - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] I have added necessary documentation (if applicable) - [ ] All new and existing tests pass - [ ] I have searched existing PRs to ensure this change hasn't been submitted already - [ ] I have linked any relevant issues in the description - [ ] My commits have clear and descriptive messages ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
This commit is contained in:
commit
8cc55ac0b2
18 changed files with 84 additions and 92 deletions
|
|
@ -169,8 +169,9 @@ REQUIRE_AUTHENTICATION=False
|
|||
# Vector: LanceDB
|
||||
# Graph: KuzuDB
|
||||
#
|
||||
# It enforces LanceDB and KuzuDB use and uses them to create databases per Cognee user + dataset
|
||||
ENABLE_BACKEND_ACCESS_CONTROL=False
|
||||
# It enforces creation of databases per Cognee user + dataset. Does not work with some graph and database providers.
|
||||
# Disable mode when using not supported graph/vector databases.
|
||||
ENABLE_BACKEND_ACCESS_CONTROL=True
|
||||
|
||||
################################################################################
|
||||
# ☁️ Cloud Sync Settings
|
||||
|
|
|
|||
3
.github/workflows/search_db_tests.yml
vendored
3
.github/workflows/search_db_tests.yml
vendored
|
|
@ -84,6 +84,7 @@ jobs:
|
|||
GRAPH_DATABASE_PROVIDER: 'neo4j'
|
||||
VECTOR_DB_PROVIDER: 'lancedb'
|
||||
DB_PROVIDER: 'sqlite'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
|
|
@ -135,6 +136,7 @@ jobs:
|
|||
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
|
||||
GRAPH_DATABASE_PROVIDER: 'kuzu'
|
||||
VECTOR_DB_PROVIDER: 'pgvector'
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
DB_PROVIDER: 'postgres'
|
||||
DB_NAME: 'cognee_db'
|
||||
DB_HOST: '127.0.0.1'
|
||||
|
|
@ -197,6 +199,7 @@ jobs:
|
|||
GRAPH_DATABASE_URL: ${{ steps.neo4j.outputs.neo4j-url }}
|
||||
GRAPH_DATABASE_USERNAME: ${{ steps.neo4j.outputs.neo4j-username }}
|
||||
GRAPH_DATABASE_PASSWORD: ${{ steps.neo4j.outputs.neo4j-password }}
|
||||
ENABLE_BACKEND_ACCESS_CONTROL: 'false'
|
||||
DB_NAME: cognee_db
|
||||
DB_HOST: 127.0.0.1
|
||||
DB_PORT: 5432
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from typing import Union
|
|||
from uuid import UUID
|
||||
|
||||
from cognee.base_config import get_base_config
|
||||
from cognee.infrastructure.databases.vector.config import get_vectordb_context_config
|
||||
from cognee.infrastructure.databases.graph.config import get_graph_context_config
|
||||
from cognee.infrastructure.databases.utils import get_or_create_dataset_database
|
||||
from cognee.infrastructure.files.storage.config import file_storage_config
|
||||
from cognee.modules.users.methods import get_user
|
||||
|
|
@ -14,11 +16,40 @@ vector_db_config = ContextVar("vector_db_config", default=None)
|
|||
graph_db_config = ContextVar("graph_db_config", default=None)
|
||||
session_user = ContextVar("session_user", default=None)
|
||||
|
||||
vector_dbs_with_multi_user_support = ["lancedb"]
|
||||
graph_dbs_with_multi_user_support = ["kuzu"]
|
||||
|
||||
|
||||
async def set_session_user_context_variable(user):
|
||||
session_user.set(user)
|
||||
|
||||
|
||||
def multi_user_support_possible():
|
||||
graph_db_config = get_graph_context_config()
|
||||
vector_db_config = get_vectordb_context_config()
|
||||
return (
|
||||
graph_db_config["graph_database_provider"] in graph_dbs_with_multi_user_support
|
||||
and vector_db_config["vector_db_provider"] in vector_dbs_with_multi_user_support
|
||||
)
|
||||
|
||||
|
||||
def backend_access_control_enabled():
|
||||
backend_access_control = os.environ.get("ENABLE_BACKEND_ACCESS_CONTROL", None)
|
||||
if backend_access_control is None:
|
||||
# If backend access control is not defined in environment variables,
|
||||
# enable it by default if graph and vector DBs can support it, otherwise disable it
|
||||
return multi_user_support_possible()
|
||||
elif backend_access_control.lower() == "true":
|
||||
# If enabled, ensure that the current graph and vector DBs can support it
|
||||
multi_user_support = multi_user_support_possible()
|
||||
if not multi_user_support:
|
||||
raise EnvironmentError(
|
||||
"ENABLE_BACKEND_ACCESS_CONTROL is set to true but the current graph and/or vector databases do not support multi-user access control. Please use supported databases or disable backend access control."
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def set_database_global_context_variables(dataset: Union[str, UUID], user_id: UUID):
|
||||
"""
|
||||
If backend access control is enabled this function will ensure all datasets have their own databases,
|
||||
|
|
@ -40,7 +71,7 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
|
||||
base_config = get_base_config()
|
||||
|
||||
if not os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
if not backend_access_control_enabled():
|
||||
return
|
||||
|
||||
user = await get_user(user_id)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
|
|
@ -9,6 +8,7 @@ from cognee.infrastructure.databases.graph import get_graph_engine
|
|||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.context_global_variables import set_database_global_context_variables
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
from cognee.modules.engine.models.node_set import NodeSet
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge
|
||||
|
|
@ -74,7 +74,7 @@ async def search(
|
|||
)
|
||||
|
||||
# Use search function filtered by permissions if access control is enabled
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
if backend_access_control_enabled():
|
||||
search_results = await authorized_search(
|
||||
query_type=query_type,
|
||||
query_text=query_text,
|
||||
|
|
@ -156,7 +156,7 @@ async def search(
|
|||
)
|
||||
else:
|
||||
# This is for maintaining backwards compatibility
|
||||
if os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true":
|
||||
if backend_access_control_enabled():
|
||||
return_value = []
|
||||
for search_result in search_results:
|
||||
prepared_search_results = await prepare_search_result(search_result)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from ..models import User
|
|||
from ..get_fastapi_users import get_fastapi_users
|
||||
from .get_default_user import get_default_user
|
||||
from cognee.shared.logging_utils import get_logger
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
|
||||
logger = get_logger("get_authenticated_user")
|
||||
|
|
@ -12,7 +13,7 @@ logger = get_logger("get_authenticated_user")
|
|||
# Check environment variable to determine authentication requirement
|
||||
REQUIRE_AUTHENTICATION = (
|
||||
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
|
||||
or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true"
|
||||
or backend_access_control_enabled()
|
||||
)
|
||||
|
||||
fastapi_users = get_fastapi_users()
|
||||
|
|
|
|||
|
|
@ -39,12 +39,12 @@ async def main():
|
|||
|
||||
answer = await cognee.search("Do programmers change light bulbs?")
|
||||
assert len(answer) != 0
|
||||
lowercase_answer = answer[0].lower()
|
||||
lowercase_answer = answer[0]["search_result"][0].lower()
|
||||
assert ("no" in lowercase_answer) or ("none" in lowercase_answer)
|
||||
|
||||
answer = await cognee.search("What colours are there in the presentation table?")
|
||||
assert len(answer) != 0
|
||||
lowercase_answer = answer[0].lower()
|
||||
lowercase_answer = answer[0]["search_result"][0].lower()
|
||||
assert (
|
||||
("red" in lowercase_answer)
|
||||
and ("blue" in lowercase_answer)
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ async def main():
|
|||
extraction_tasks=extraction_tasks,
|
||||
enrichment_tasks=enrichment_tasks,
|
||||
data=[{}],
|
||||
dataset="feedback_enrichment_test_memify",
|
||||
dataset=dataset_name,
|
||||
)
|
||||
|
||||
nodes_after, edges_after = await graph_engine.get_graph_data()
|
||||
|
|
|
|||
|
|
@ -90,15 +90,17 @@ async def main():
|
|||
)
|
||||
|
||||
search_results = await cognee.search(
|
||||
query_type=SearchType.GRAPH_COMPLETION, query_text="What information do you contain?"
|
||||
query_type=SearchType.GRAPH_COMPLETION,
|
||||
query_text="What information do you contain?",
|
||||
dataset_ids=[pipeline_run_obj.dataset_id],
|
||||
)
|
||||
assert "Mark" in search_results[0], (
|
||||
assert "Mark" in search_results[0]["search_result"][0], (
|
||||
"Failed to update document, no mention of Mark in search results"
|
||||
)
|
||||
assert "Cindy" in search_results[0], (
|
||||
assert "Cindy" in search_results[0]["search_result"][0], (
|
||||
"Failed to update document, no mention of Cindy in search results"
|
||||
)
|
||||
assert "Artificial intelligence" not in search_results[0], (
|
||||
assert "Artificial intelligence" not in search_results[0]["search_result"][0], (
|
||||
"Failed to update document, Artificial intelligence still mentioned in search results"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ def normalize_node_name(node_name: str) -> str:
|
|||
|
||||
|
||||
async def setup_test_db():
|
||||
# Disable backend access control to migrate relational data
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
|
||||
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -146,7 +146,13 @@ async def main():
|
|||
assert len(search_results) == 1, (
|
||||
f"{name}: expected single-element list, got {len(search_results)}"
|
||||
)
|
||||
text = search_results[0]
|
||||
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
if backend_access_control_enabled():
|
||||
text = search_results[0]["search_result"][0]
|
||||
else:
|
||||
text = search_results[0]
|
||||
assert isinstance(text, str), f"{name}: element should be a string"
|
||||
assert text.strip(), f"{name}: string should not be empty"
|
||||
assert "netherlands" in text.lower(), (
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
|
@ -5,8 +6,6 @@ from fastapi.testclient import TestClient
|
|||
from types import SimpleNamespace
|
||||
import importlib
|
||||
|
||||
from cognee.api.client import app
|
||||
|
||||
|
||||
# Fixtures for reuse across test classes
|
||||
@pytest.fixture
|
||||
|
|
@ -32,6 +31,10 @@ def mock_authenticated_user():
|
|||
)
|
||||
|
||||
|
||||
# To turn off authentication we need to set the environment variable before importing the module
|
||||
# Also both require_authentication and backend access control must be false
|
||||
os.environ["REQUIRE_AUTHENTICATION"] = "false"
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
|
||||
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
|
||||
|
||||
|
||||
|
|
@ -40,6 +43,8 @@ class TestConditionalAuthenticationEndpoints:
|
|||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
from cognee.api.client import app
|
||||
|
||||
"""Create a test client."""
|
||||
return TestClient(app)
|
||||
|
||||
|
|
@ -133,6 +138,8 @@ class TestConditionalAuthenticationBehavior:
|
|||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
from cognee.api.client import app
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
@ -209,6 +216,8 @@ class TestConditionalAuthenticationErrorHandling:
|
|||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
from cognee.api.client import app
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||
|
|
@ -232,7 +241,7 @@ class TestConditionalAuthenticationErrorHandling:
|
|||
# The exact error message may vary depending on the actual database connection
|
||||
# The important thing is that we get a 500 error when user creation fails
|
||||
|
||||
def test_current_environment_configuration(self):
|
||||
def test_current_environment_configuration(self, client):
|
||||
"""Test that current environment configuration is working properly."""
|
||||
# This tests the actual module state without trying to change it
|
||||
from cognee.modules.users.methods.get_authenticated_user import (
|
||||
|
|
|
|||
|
|
@ -107,29 +107,10 @@ class TestConditionalAuthenticationIntegration:
|
|||
# REQUIRE_AUTHENTICATION should be a boolean
|
||||
assert isinstance(REQUIRE_AUTHENTICATION, bool)
|
||||
|
||||
# Currently should be False (optional authentication)
|
||||
assert not REQUIRE_AUTHENTICATION
|
||||
|
||||
|
||||
class TestConditionalAuthenticationEnvironmentVariables:
|
||||
"""Test environment variable handling."""
|
||||
|
||||
def test_require_authentication_default_false(self):
|
||||
"""Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Remove module from cache to force fresh import
|
||||
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# Import after patching environment - module will see empty environment
|
||||
from cognee.modules.users.methods.get_authenticated_user import (
|
||||
REQUIRE_AUTHENTICATION,
|
||||
)
|
||||
|
||||
importlib.invalidate_caches()
|
||||
assert not REQUIRE_AUTHENTICATION
|
||||
|
||||
def test_require_authentication_true(self):
|
||||
"""Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported."""
|
||||
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}):
|
||||
|
|
@ -145,50 +126,6 @@ class TestConditionalAuthenticationEnvironmentVariables:
|
|||
|
||||
assert REQUIRE_AUTHENTICATION
|
||||
|
||||
def test_require_authentication_false_explicit(self):
|
||||
"""Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported."""
|
||||
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}):
|
||||
# Remove module from cache to force fresh import
|
||||
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# Import after patching environment - module will see REQUIRE_AUTHENTICATION=false
|
||||
from cognee.modules.users.methods.get_authenticated_user import (
|
||||
REQUIRE_AUTHENTICATION,
|
||||
)
|
||||
|
||||
assert not REQUIRE_AUTHENTICATION
|
||||
|
||||
def test_require_authentication_case_insensitive(self):
|
||||
"""Test that environment variable parsing is case insensitive when imported."""
|
||||
test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"]
|
||||
|
||||
for case in test_cases:
|
||||
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}):
|
||||
# Remove module from cache to force fresh import
|
||||
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# Import after patching environment
|
||||
from cognee.modules.users.methods.get_authenticated_user import (
|
||||
REQUIRE_AUTHENTICATION,
|
||||
)
|
||||
|
||||
expected = case.lower() == "true"
|
||||
assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}"
|
||||
|
||||
def test_current_require_authentication_value(self):
|
||||
"""Test that the current REQUIRE_AUTHENTICATION module value is as expected."""
|
||||
from cognee.modules.users.methods.get_authenticated_user import (
|
||||
REQUIRE_AUTHENTICATION,
|
||||
)
|
||||
|
||||
# The module-level variable should currently be False (set at import time)
|
||||
assert isinstance(REQUIRE_AUTHENTICATION, bool)
|
||||
assert not REQUIRE_AUTHENTICATION
|
||||
|
||||
|
||||
class TestConditionalAuthenticationEdgeCases:
|
||||
"""Test edge cases and error scenarios."""
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ async def run_procurement_example():
|
|||
for q in questions:
|
||||
print(f"Question: \n{q}")
|
||||
results = await procurement_system.search_memory(q, search_categories=[category])
|
||||
top_answer = results[category][0]
|
||||
top_answer = results[category][0]["search_result"][0]
|
||||
print(f"Answer: \n{top_answer.strip()}\n")
|
||||
research_notes[category].append({"question": q, "answer": top_answer})
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
import cognee
|
||||
from cognee import SearchType
|
||||
from cognee.shared.logging_utils import setup_logging, ERROR
|
||||
|
|
@ -8,6 +10,9 @@ from cognee.api.v1.cognify.code_graph_pipeline import run_code_graph_pipeline
|
|||
|
||||
|
||||
async def main(repo_path, include_docs):
|
||||
# Disable permissions feature for this example
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
|
||||
|
||||
run_status = False
|
||||
async for run_status in run_code_graph_pipeline(repo_path, include_docs=include_docs):
|
||||
run_status = run_status
|
||||
|
|
|
|||
|
|
@ -67,7 +67,6 @@ async def run_feedback_enrichment_memify(last_n: int = 5):
|
|||
extraction_tasks=extraction_tasks,
|
||||
enrichment_tasks=enrichment_tasks,
|
||||
data=[{}], # A placeholder to prevent fetching the entire graph
|
||||
dataset="feedback_enrichment_minimal",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ async def main():
|
|||
)
|
||||
|
||||
print("Coding rules created by memify:")
|
||||
for coding_rule in coding_rules:
|
||||
for coding_rule in coding_rules[0]["search_result"][0]:
|
||||
print("- " + coding_rule)
|
||||
|
||||
# Visualize new graph with added memify context
|
||||
|
|
|
|||
|
|
@ -31,6 +31,9 @@ from cognee.infrastructure.databases.vector.pgvector import (
|
|||
|
||||
|
||||
async def main():
|
||||
# Disable backend access control to migrate relational data
|
||||
os.environ["ENABLE_BACKEND_ACCESS_CONTROL"] = "false"
|
||||
|
||||
# Clean all data stored in Cognee
|
||||
await cognee.prune.prune_data()
|
||||
await cognee.prune.prune_system(metadata=True)
|
||||
|
|
|
|||
|
|
@ -59,14 +59,6 @@ async def main():
|
|||
for result_text in search_results:
|
||||
print(result_text)
|
||||
|
||||
# Example output:
|
||||
# ({'id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'natural language processing', 'description': 'An interdisciplinary subfield of computer science and information retrieval.'}, {'relationship_name': 'is_a_subfield_of', 'source_node_id': UUID('bc338a39-64d6-549a-acec-da60846dd90d'), 'target_node_id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 15, 473137, tzinfo=datetime.timezone.utc)}, {'id': UUID('6218dbab-eb6a-5759-a864-b3419755ffe0'), 'updated_at': datetime.datetime(2024, 11, 21, 12, 23, 1, 211808, tzinfo=datetime.timezone.utc), 'name': 'computer science', 'description': 'The study of computation and information processing.'})
|
||||
# (...)
|
||||
# It represents nodes and relationships in the knowledge graph:
|
||||
# - The first element is the source node (e.g., 'natural language processing').
|
||||
# - The second element is the relationship between nodes (e.g., 'is_a_subfield_of').
|
||||
# - The third element is the target node (e.g., 'computer science').
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = setup_logging(log_level=ERROR)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue