diff --git a/.env.template b/.env.template
index 84dc46d1c..ee62f1d3d 100644
--- a/.env.template
+++ b/.env.template
@@ -124,6 +124,10 @@ ALLOW_HTTP_REQUESTS=True
# When set to False errors during data processing will be returned as info but not raised to allow handling of faulty documents
RAISE_INCREMENTAL_LOADING_ERRORS=True
+# When set to True, the Cognee backend requires authentication for requests to the API.
+# Setting it to False has no effect while ENABLE_BACKEND_ACCESS_CONTROL is True, so disable both to allow unauthenticated requests.
+REQUIRE_AUTHENTICATION=False
+
# Set this variable to True to enforce usage of backend access control for Cognee
# Note: This is only currently supported by the following databases:
# Relational: SQLite, Postgres
diff --git a/README.md b/README.md
index a7e7f1e05..e618d5bf9 100644
--- a/README.md
+++ b/README.md
@@ -79,7 +79,9 @@ More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github
## Get Started
-Get started quickly with a Google Colab notebook , Deepnote notebook or starter repo
+Get started quickly with a Google Colab notebook , Deepnote notebook or starter repo
+
+
## Contributing
diff --git a/cognee/api/client.py b/cognee/api/client.py
index 215e4a17e..7588638c3 100644
--- a/cognee/api/client.py
+++ b/cognee/api/client.py
@@ -33,6 +33,7 @@ from cognee.api.v1.users.routers import (
get_users_router,
get_visualize_router,
)
+from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION
logger = get_logger()
@@ -110,7 +111,11 @@ def custom_openapi():
},
}
- openapi_schema["security"] = [{"BearerAuth": []}, {"CookieAuth": []}]
+ if REQUIRE_AUTHENTICATION:
+ openapi_schema["security"] = [{"BearerAuth": []}, {"CookieAuth": []}]
+
+    # When authentication is optional, no global security requirement is declared,
+    # so the OpenAPI schema does not mark every endpoint as requiring credentials.
app.openapi_schema = openapi_schema
diff --git a/cognee/api/health.py b/cognee/api/health.py
index 0bfbca806..bdb3b1fe3 100644
--- a/cognee/api/health.py
+++ b/cognee/api/health.py
@@ -53,7 +53,7 @@ class HealthChecker:
# Test connection by creating a session
session = engine.get_session()
if session:
- await session.close()
+ session.close()
response_time = int((time.time() - start_time) * 1000)
return ComponentHealth(
@@ -190,14 +190,13 @@ class HealthChecker:
"""Check LLM provider health (non-critical)."""
start_time = time.time()
try:
- from cognee.infrastructure.llm.get_llm_client import get_llm_client
+ from cognee.infrastructure.llm.LLMGateway import LLMGateway
from cognee.infrastructure.llm.config import get_llm_config
config = get_llm_config()
# Test actual API connection with minimal request
- client = get_llm_client()
- await client.show_prompt("test", "test")
+ LLMGateway.show_prompt("test", "test")
response_time = int((time.time() - start_time) * 1000)
return ComponentHealth(
diff --git a/cognee/api/v1/datasets/routers/get_datasets_router.py b/cognee/api/v1/datasets/routers/get_datasets_router.py
index 8052e3864..ff310e4b4 100644
--- a/cognee/api/v1/datasets/routers/get_datasets_router.py
+++ b/cognee/api/v1/datasets/routers/get_datasets_router.py
@@ -114,7 +114,8 @@ def get_datasets_router() -> APIRouter:
@router.post("", response_model=DatasetDTO)
async def create_new_dataset(
- dataset_data: DatasetCreationPayload, user: User = Depends(get_authenticated_user)
+ dataset_data: DatasetCreationPayload,
+ user: User = Depends(get_authenticated_user),
):
"""
Create a new dataset or return existing dataset with the same name.
diff --git a/cognee/base_config.py b/cognee/base_config.py
index aa0b14008..940846128 100644
--- a/cognee/base_config.py
+++ b/cognee/base_config.py
@@ -1,15 +1,24 @@
import os
from typing import Optional
from functools import lru_cache
-from cognee.root_dir import get_absolute_path
+from cognee.root_dir import get_absolute_path, ensure_absolute_path
from cognee.modules.observability.observers import Observer
from pydantic_settings import BaseSettings, SettingsConfigDict
+import pydantic
class BaseConfig(BaseSettings):
data_root_directory: str = get_absolute_path(".data_storage")
system_root_directory: str = get_absolute_path(".cognee_system")
monitoring_tool: object = Observer.LANGFUSE
+
+ @pydantic.model_validator(mode="after")
+ def validate_paths(self):
+ # Require absolute paths for root directories
+ self.data_root_directory = ensure_absolute_path(self.data_root_directory)
+ self.system_root_directory = ensure_absolute_path(self.system_root_directory)
+ return self
+
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY")
langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")
diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py
index cdc001863..d96de4520 100644
--- a/cognee/infrastructure/databases/graph/config.py
+++ b/cognee/infrastructure/databases/graph/config.py
@@ -6,6 +6,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
import pydantic
from pydantic import Field
from cognee.base_config import get_base_config
+from cognee.root_dir import ensure_absolute_path
from cognee.shared.data_models import KnowledgeGraph
@@ -51,15 +52,20 @@ class GraphConfig(BaseSettings):
@pydantic.model_validator(mode="after")
def fill_derived(cls, values):
provider = values.graph_database_provider.lower()
+ base_config = get_base_config()
# Set default filename if no filename is provided
if not values.graph_filename:
values.graph_filename = f"cognee_graph_{provider}"
- # Set file path based on graph database provider if no file path is provided
- if not values.graph_file_path:
- base_config = get_base_config()
-
+ # Handle graph file path
+ if values.graph_file_path:
+            # Join directory and filename; a relative result is rejected with ValueError
+ values.graph_file_path = ensure_absolute_path(
+ os.path.join(values.graph_file_path, values.graph_filename)
+ )
+ else:
+ # Default path
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
values.graph_file_path = os.path.join(databases_directory_path, values.graph_filename)
diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py
index 07a3d1e05..f8fad473e 100644
--- a/cognee/infrastructure/databases/vector/config.py
+++ b/cognee/infrastructure/databases/vector/config.py
@@ -1,9 +1,11 @@
import os
import pydantic
+from pathlib import Path
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict
from cognee.base_config import get_base_config
+from cognee.root_dir import ensure_absolute_path
class VectorConfig(BaseSettings):
@@ -11,11 +13,9 @@ class VectorConfig(BaseSettings):
Manage the configuration settings for the vector database.
Public methods:
-
- to_dict: Convert the configuration to a dictionary.
Instance variables:
-
- vector_db_url: The URL of the vector database.
- vector_db_port: The port for the vector database.
- vector_db_key: The key for accessing the vector database.
@@ -30,10 +30,17 @@ class VectorConfig(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@pydantic.model_validator(mode="after")
- def fill_derived(cls, values):
- # Set file path based on graph database provider if no file path is provided
- if not values.vector_db_url:
- base_config = get_base_config()
+    def validate_paths(cls, values):
+        base_config = get_base_config()
+
+        # vector_db_url may be a filesystem path or a remote URL; only existing paths are validated
+        if values.vector_db_url and Path(values.vector_db_url).exists():
+            # Existing path: must be absolute (ensure_absolute_path raises ValueError otherwise)
+            values.vector_db_url = ensure_absolute_path(
+                values.vector_db_url,
+            )
+        elif not values.vector_db_url:
+            # Nothing configured: use the default path (a provided URL is left untouched)
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
values.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
diff --git a/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py
index dc8443459..acb041e76 100644
--- a/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py
+++ b/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py
@@ -4,7 +4,7 @@ from fastembed import TextEmbedding
import litellm
import os
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
-from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
+from cognee.infrastructure.databases.exceptions import EmbeddingException
from cognee.infrastructure.llm.tokenizer.TikToken import (
TikTokenTokenizer,
)
diff --git a/cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py b/cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py
index 24312dab1..27688d2c9 100644
--- a/cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py
+++ b/cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py
@@ -250,9 +250,7 @@ def embedding_rate_limit_sync(func):
logger.warning(error_msg)
# Create a custom embedding rate limit exception
- from cognee.infrastructure.databases.exceptions.EmbeddingException import (
- EmbeddingException,
- )
+ from cognee.infrastructure.databases.exceptions import EmbeddingException
raise EmbeddingException(error_msg)
@@ -307,9 +305,7 @@ def embedding_rate_limit_async(func):
logger.warning(error_msg)
# Create a custom embedding rate limit exception
- from cognee.infrastructure.databases.exceptions.EmbeddingException import (
- EmbeddingException,
- )
+ from cognee.infrastructure.databases.exceptions import EmbeddingException
raise EmbeddingException(error_msg)
diff --git a/cognee/modules/users/methods/__init__.py b/cognee/modules/users/methods/__init__.py
index 969615b89..5d45df97b 100644
--- a/cognee/modules/users/methods/__init__.py
+++ b/cognee/modules/users/methods/__init__.py
@@ -4,4 +4,7 @@ from .delete_user import delete_user
from .get_default_user import get_default_user
from .get_user_by_email import get_user_by_email
from .create_default_user import create_default_user
-from .get_authenticated_user import get_authenticated_user
+from .get_authenticated_user import (
+ get_authenticated_user,
+ REQUIRE_AUTHENTICATION,
+)
diff --git a/cognee/modules/users/methods/get_authenticated_user.py b/cognee/modules/users/methods/get_authenticated_user.py
index b60ddfe28..a2dd2330e 100644
--- a/cognee/modules/users/methods/get_authenticated_user.py
+++ b/cognee/modules/users/methods/get_authenticated_user.py
@@ -1,48 +1,42 @@
+import os
+from typing import Optional
+from fastapi import Depends, HTTPException
+from ..models import User
from ..get_fastapi_users import get_fastapi_users
+from .get_default_user import get_default_user
+from cognee.shared.logging_utils import get_logger
+logger = get_logger("get_authenticated_user")
+
+# Check environment variable to determine authentication requirement
+REQUIRE_AUTHENTICATION = (
+ os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
+ or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true"
+)
+
fastapi_users = get_fastapi_users()
-get_authenticated_user = fastapi_users.current_user(active=True)
-
-# from types import SimpleNamespace
-
-# from ..get_fastapi_users import get_fastapi_users
-# from fastapi import HTTPException, Security
-# from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-# import os
-# import jwt
-
-# from uuid import UUID
-
-# fastapi_users = get_fastapi_users()
-
-# # Allows Swagger to understand authorization type and allow single sign on for the Swagger docs to test backend
-# bearer_scheme = HTTPBearer(scheme_name="BearerAuth", description="Paste **Bearer <JWT>**")
+_auth_dependency = fastapi_users.current_user(active=True, optional=not REQUIRE_AUTHENTICATION)
-# async def get_authenticated_user(
-# creds: HTTPAuthorizationCredentials = Security(bearer_scheme),
-# ) -> SimpleNamespace:
-# """
-# Extract and validate the JWT presented in the Authorization header.
-# """
-# if creds is None: # header missing
-# raise HTTPException(status_code=401, detail="Not authenticated")
+async def get_authenticated_user(
+ user: Optional[User] = Depends(_auth_dependency),
+) -> User:
+ """
+ Get authenticated user with environment-controlled behavior:
+ - If REQUIRE_AUTHENTICATION=true: Enforces authentication (raises 401 if not authenticated)
+ - If REQUIRE_AUTHENTICATION=false: Falls back to default user if not authenticated
-# if creds.scheme.lower() != "bearer": # shouldn't happen extra guard
-# raise HTTPException(status_code=401, detail="Invalid authentication scheme")
+ Always returns a User object for consistent typing.
+ """
+    if user is None:
+        # No authenticated user was supplied: fall back to the default user (cause is chained on failure)
+        try:
+            user = await get_default_user()
+        except Exception as e:
+            error_message = f"Failed to create default user: {str(e)}"
+            logger.error(error_message)
+            raise HTTPException(status_code=500, detail=error_message) from e
-# token = creds.credentials
-# try:
-# payload = jwt.decode(
-# token, os.getenv("FASTAPI_USERS_JWT_SECRET", "super_secret"), algorithms=["HS256"]
-# )
-
-# auth_data = SimpleNamespace(id=UUID(payload["user_id"]))
-# return auth_data
-
-# except jwt.ExpiredSignatureError:
-# raise HTTPException(status_code=401, detail="Token has expired")
-# except jwt.InvalidTokenError:
-# raise HTTPException(status_code=401, detail="Invalid token")
+ return user
diff --git a/cognee/root_dir.py b/cognee/root_dir.py
index 2e21d5ce3..46d8fcb69 100644
--- a/cognee/root_dir.py
+++ b/cognee/root_dir.py
@@ -1,4 +1,5 @@
from pathlib import Path
+from typing import Optional
ROOT_DIR = Path(__file__).resolve().parent
@@ -6,3 +7,21 @@ ROOT_DIR = Path(__file__).resolve().parent
def get_absolute_path(path_from_root: str) -> str:
absolute_path = ROOT_DIR / path_from_root
return str(absolute_path.resolve())
+
+
+def ensure_absolute_path(path: Optional[str]) -> str:
+    """Return ``path`` as a resolved absolute path, rejecting relative input.
+
+    Args:
+        path: The path to validate; a leading ``~`` is expanded before the check.
+    Returns:
+        The resolved absolute path as a string.
+    Raises:
+        ValueError: If ``path`` is None or is still relative after ``~`` expansion.
+    """
+    if path is None:
+        raise ValueError("Path cannot be None")
+    path_obj = Path(path).expanduser()
+    if not path_obj.is_absolute():
+        raise ValueError(f"Path must be absolute. Got relative path: {path}")
+    return str(path_obj.resolve())
diff --git a/cognee/tests/unit/api/__init__.py b/cognee/tests/unit/api/__init__.py
new file mode 100644
index 000000000..2b1755712
--- /dev/null
+++ b/cognee/tests/unit/api/__init__.py
@@ -0,0 +1 @@
+# Test package for API tests
diff --git a/cognee/tests/unit/api/test_conditional_authentication_endpoints.py b/cognee/tests/unit/api/test_conditional_authentication_endpoints.py
new file mode 100644
index 000000000..2eabee91a
--- /dev/null
+++ b/cognee/tests/unit/api/test_conditional_authentication_endpoints.py
@@ -0,0 +1,246 @@
+import pytest
+from unittest.mock import patch, AsyncMock, MagicMock
+from uuid import uuid4
+from fastapi.testclient import TestClient
+from types import SimpleNamespace
+import importlib
+
+from cognee.api.client import app
+
+
+# Fixtures for reuse across test classes
+@pytest.fixture
+def mock_default_user():
+ """Mock default user for testing."""
+ return SimpleNamespace(
+ id=uuid4(), email="default@example.com", is_active=True, tenant_id=uuid4()
+ )
+
+
+@pytest.fixture
+def mock_authenticated_user():
+ """Mock authenticated user for testing."""
+ from cognee.modules.users.models import User
+
+ return User(
+ id=uuid4(),
+ email="auth@example.com",
+ hashed_password="hashed",
+ is_active=True,
+ is_verified=True,
+ tenant_id=uuid4(),
+ )
+
+
+gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
+
+
+class TestConditionalAuthenticationEndpoints:
+ """Test that API endpoints work correctly with conditional authentication."""
+
+ @pytest.fixture
+ def client(self):
+ """Create a test client."""
+ return TestClient(app)
+
+ def test_health_endpoint_no_auth_required(self, client):
+ """Test that health endpoint works without authentication."""
+ response = client.get("/health")
+ assert response.status_code in [200, 503] # 503 is also acceptable for health checks
+
+ def test_root_endpoint_no_auth_required(self, client):
+ """Test that root endpoint works without authentication."""
+ response = client.get("/")
+ assert response.status_code == 200
+ assert response.json() == {"message": "Hello, World, I am alive!"}
+
+ @patch(
+ "cognee.api.client.REQUIRE_AUTHENTICATION",
+ False,
+ )
+ def test_openapi_schema_no_global_security(self, client):
+ """Test that OpenAPI schema doesn't require global authentication."""
+ response = client.get("/openapi.json")
+ assert response.status_code == 200
+
+ schema = response.json()
+
+ # Should not have global security requirement
+ global_security = schema.get("security", [])
+ assert global_security == []
+
+ # But should still have security schemes defined
+ security_schemes = schema.get("components", {}).get("securitySchemes", {})
+ assert "BearerAuth" in security_schemes
+ assert "CookieAuth" in security_schemes
+
+ @patch("cognee.api.v1.add.add")
+ @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+ @patch(
+ "cognee.api.client.REQUIRE_AUTHENTICATION",
+ False,
+ )
+ def test_add_endpoint_with_conditional_auth(
+ self, mock_get_default_user, mock_add, client, mock_default_user
+ ):
+ """Test add endpoint works with conditional authentication."""
+ mock_get_default_user.return_value = mock_default_user
+ mock_add.return_value = MagicMock(
+ model_dump=lambda: {"status": "success", "pipeline_run_id": str(uuid4())}
+ )
+
+ # Test file upload without authentication
+ files = {"data": ("test.txt", b"test content", "text/plain")}
+ form_data = {"datasetName": "test_dataset"}
+
+ response = client.post("/api/v1/add", files=files, data=form_data)
+
+ assert mock_get_default_user.call_count == 1
+
+ # Core test: authentication is not required (should not get 401)
+ assert response.status_code != 401
+
+ @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+ @patch(
+ "cognee.api.client.REQUIRE_AUTHENTICATION",
+ False,
+ )
+ def test_conditional_authentication_works_with_current_environment(
+ self, mock_get_default_user, client
+ ):
+ """Test that conditional authentication works with the current environment setup."""
+ # Since REQUIRE_AUTHENTICATION defaults to "false", we expect endpoints to work without auth
+ # This tests the actual integration behavior
+
+ mock_get_default_user.return_value = SimpleNamespace(
+ id=uuid4(), email="default@example.com", is_active=True, tenant_id=uuid4()
+ )
+
+ files = {"data": ("test.txt", b"test content", "text/plain")}
+ form_data = {"datasetName": "test_dataset"}
+
+ response = client.post("/api/v1/add", files=files, data=form_data)
+
+ assert mock_get_default_user.call_count == 1
+
+ # Core test: authentication is not required (should not get 401)
+ assert response.status_code != 401
+ # Note: This test verifies conditional authentication works in the current environment
+
+
+class TestConditionalAuthenticationBehavior:
+ """Test the behavior of conditional authentication across different endpoints."""
+
+ @pytest.fixture
+ def client(self):
+ return TestClient(app)
+
+    @pytest.mark.parametrize(
+        "endpoint,method",
+        [
+            ("/api/v1/search", "GET"),
+            ("/api/v1/datasets", "GET"),
+        ],
+    )
+    @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+    def test_get_endpoints_work_without_auth(
+        self, mock_get_default, client, endpoint, method, mock_default_user
+    ):
+        """Test that GET endpoints work without authentication (with current environment)."""
+        mock_get_default.return_value = mock_default_user
+
+        if method == "GET":
+            response = client.get(endpoint)
+        elif method == "POST":
+            response = client.post(endpoint, json={})
+
+        assert mock_get_default.call_count == 1
+
+        # Should not return 401 Unauthorized (authentication is optional by default)
+        assert response.status_code != 401
+
+        # May return other errors due to missing data/config, but not auth errors
+        if response.status_code >= 400:
+            # Guard only the parsing: asserts inside `except Exception` would be swallowed
+            try:
+                error_detail = str(response.json().get("detail", ""))
+            except (ValueError, AttributeError):
+                error_detail = ""  # Non-JSON or non-object body: nothing to inspect
+            assert "authenticate" not in error_detail.lower()
+            assert "unauthorized" not in error_detail.lower()
+
+ gsm_mod = importlib.import_module("cognee.modules.settings.get_settings")
+
+ @patch.object(gsm_mod, "get_vectordb_config")
+ @patch.object(gsm_mod, "get_llm_config")
+ @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+ def test_settings_endpoint_integration(
+ self, mock_get_default, mock_llm_config, mock_vector_config, client, mock_default_user
+ ):
+ """Test that settings endpoint integration works with conditional authentication."""
+ mock_get_default.return_value = mock_default_user
+
+ # Mock configurations to avoid validation errors
+ mock_llm_config.return_value = SimpleNamespace(
+ llm_provider="openai",
+ llm_model="gpt-4o",
+ llm_endpoint=None,
+ llm_api_version=None,
+ llm_api_key="test_key_1234567890",
+ )
+
+ mock_vector_config.return_value = SimpleNamespace(
+ vector_db_provider="lancedb",
+ vector_db_url="localhost:5432", # Must be string, not None
+ vector_db_key="test_vector_key",
+ )
+
+ response = client.get("/api/v1/settings")
+
+ assert mock_get_default.call_count == 1
+
+ # Core test: authentication is not required (should not get 401)
+ assert response.status_code != 401
+ # Note: This test verifies conditional authentication works for settings endpoint
+
+
+class TestConditionalAuthenticationErrorHandling:
+ """Test error handling in conditional authentication."""
+
+ @pytest.fixture
+ def client(self):
+ return TestClient(app)
+
+ @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+ def test_get_default_user_fails(self, mock_get_default, client):
+ """Test behavior when get_default_user fails (with current environment)."""
+ mock_get_default.side_effect = Exception("Database connection failed")
+
+ # The error should propagate - either as a 500 error or as an exception
+ files = {"data": ("test.txt", b"test content", "text/plain")}
+ form_data = {"datasetName": "test_dataset"}
+
+ # Test that the exception is properly converted to HTTP 500
+ response = client.post("/api/v1/add", files=files, data=form_data)
+
+ # Should return HTTP 500 Internal Server Error when get_default_user fails
+ assert response.status_code == 500
+
+ # Check that the error message is informative
+ error_detail = response.json().get("detail", "")
+ assert "Failed to create default user" in error_detail
+ # The exact error message may vary depending on the actual database connection
+ # The important thing is that we get a 500 error when user creation fails
+
+ def test_current_environment_configuration(self):
+ """Test that current environment configuration is working properly."""
+ # This tests the actual module state without trying to change it
+ from cognee.modules.users.methods.get_authenticated_user import (
+ REQUIRE_AUTHENTICATION,
+ )
+
+ # Should be a boolean value (the parsing logic works)
+ assert isinstance(REQUIRE_AUTHENTICATION, bool)
+
+ # In default environment, should be False
+ assert not REQUIRE_AUTHENTICATION
diff --git a/cognee/tests/unit/modules/users/__init__.py b/cognee/tests/unit/modules/users/__init__.py
new file mode 100644
index 000000000..a5e9995d3
--- /dev/null
+++ b/cognee/tests/unit/modules/users/__init__.py
@@ -0,0 +1 @@
+# Test package for user module tests
diff --git a/cognee/tests/unit/modules/users/test_conditional_authentication.py b/cognee/tests/unit/modules/users/test_conditional_authentication.py
new file mode 100644
index 000000000..c4368d796
--- /dev/null
+++ b/cognee/tests/unit/modules/users/test_conditional_authentication.py
@@ -0,0 +1,277 @@
+import os
+import sys
+import pytest
+from unittest.mock import AsyncMock, patch
+from uuid import uuid4
+from types import SimpleNamespace
+import importlib
+
+
+from cognee.modules.users.models import User
+
+
+gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
+
+
+class TestConditionalAuthentication:
+ """Test cases for conditional authentication functionality."""
+
+ @pytest.mark.asyncio
+ @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+ async def test_require_authentication_false_no_token_returns_default_user(
+ self, mock_get_default
+ ):
+ """Test that when REQUIRE_AUTHENTICATION=false and no token, returns default user."""
+ # Mock the default user
+ mock_default_user = SimpleNamespace(id=uuid4(), email="default@example.com", is_active=True)
+ mock_get_default.return_value = mock_default_user
+
+ # Use gau_mod.get_authenticated_user instead
+
+ # Test with None user (no authentication)
+ result = await gau_mod.get_authenticated_user(user=None)
+
+ assert result == mock_default_user
+ mock_get_default.assert_called_once()
+
+ @pytest.mark.asyncio
+ @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+ async def test_require_authentication_false_with_valid_user_returns_user(
+ self, mock_get_default
+ ):
+ """Test that when REQUIRE_AUTHENTICATION=false and valid user, returns that user."""
+ mock_authenticated_user = User(
+ id=uuid4(),
+ email="user@example.com",
+ hashed_password="hashed",
+ is_active=True,
+ is_verified=True,
+ )
+
+ # Use gau_mod.get_authenticated_user instead
+
+ # Test with authenticated user
+ result = await gau_mod.get_authenticated_user(user=mock_authenticated_user)
+
+ assert result == mock_authenticated_user
+ mock_get_default.assert_not_called()
+
+ @pytest.mark.asyncio
+ @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+ async def test_require_authentication_true_with_user_returns_user(self, mock_get_default):
+ """Test that when REQUIRE_AUTHENTICATION=true and user present, returns user."""
+ mock_authenticated_user = User(
+ id=uuid4(),
+ email="user@example.com",
+ hashed_password="hashed",
+ is_active=True,
+ is_verified=True,
+ )
+
+ # Use gau_mod.get_authenticated_user instead
+
+ result = await gau_mod.get_authenticated_user(user=mock_authenticated_user)
+
+ assert result == mock_authenticated_user
+
+
+class TestConditionalAuthenticationIntegration:
+ """Integration tests that test the full authentication flow."""
+
+ @pytest.mark.asyncio
+ async def test_fastapi_users_dependency_creation(self):
+ """Test that FastAPI Users dependency can be created correctly."""
+ from cognee.modules.users.get_fastapi_users import get_fastapi_users
+
+ fastapi_users = get_fastapi_users()
+
+ # Test that we can create optional dependency
+ optional_dependency = fastapi_users.current_user(optional=True, active=True)
+ assert callable(optional_dependency)
+
+ # Test that we can create required dependency
+ required_dependency = fastapi_users.current_user(active=True) # optional=False by default
+ assert callable(required_dependency)
+
+ @pytest.mark.asyncio
+ async def test_conditional_authentication_function_exists(self):
+ """Test that the conditional authentication function can be imported and used."""
+ from cognee.modules.users.methods.get_authenticated_user import (
+ get_authenticated_user,
+ REQUIRE_AUTHENTICATION,
+ )
+
+ # Should be callable
+ assert callable(get_authenticated_user)
+
+ # REQUIRE_AUTHENTICATION should be a boolean
+ assert isinstance(REQUIRE_AUTHENTICATION, bool)
+
+ # Currently should be False (optional authentication)
+ assert not REQUIRE_AUTHENTICATION
+
+
+class TestConditionalAuthenticationEnvironmentVariables:
+ """Test environment variable handling."""
+
+ def test_require_authentication_default_false(self):
+ """Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars."""
+ with patch.dict(os.environ, {}, clear=True):
+ # Remove module from cache to force fresh import
+ module_name = "cognee.modules.users.methods.get_authenticated_user"
+ if module_name in sys.modules:
+ del sys.modules[module_name]
+
+ # Import after patching environment - module will see empty environment
+ from cognee.modules.users.methods.get_authenticated_user import (
+ REQUIRE_AUTHENTICATION,
+ )
+
+ importlib.invalidate_caches()
+ assert not REQUIRE_AUTHENTICATION
+
+ def test_require_authentication_true(self):
+ """Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported."""
+ with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}):
+ # Remove module from cache to force fresh import
+ module_name = "cognee.modules.users.methods.get_authenticated_user"
+ if module_name in sys.modules:
+ del sys.modules[module_name]
+
+ # Import after patching environment - module will see REQUIRE_AUTHENTICATION=true
+ from cognee.modules.users.methods.get_authenticated_user import (
+ REQUIRE_AUTHENTICATION,
+ )
+
+ assert REQUIRE_AUTHENTICATION
+
+    def test_require_authentication_false_explicit(self):
+        """Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported."""
+        # ENABLE_BACKEND_ACCESS_CONTROL also switches authentication on, so pin it off too
+        env = {"REQUIRE_AUTHENTICATION": "false", "ENABLE_BACKEND_ACCESS_CONTROL": "false"}
+        with patch.dict(os.environ, env):
+            # Remove module from cache to force fresh import
+            module_name = "cognee.modules.users.methods.get_authenticated_user"
+            if module_name in sys.modules:
+                del sys.modules[module_name]
+
+            # Import after patching environment - module will see REQUIRE_AUTHENTICATION=false
+            from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION
+
+            assert not REQUIRE_AUTHENTICATION
+
+    def test_require_authentication_case_insensitive(self):
+        """Test that environment variable parsing is case insensitive when imported."""
+        test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"]
+
+        for case in test_cases:
+            env = {"REQUIRE_AUTHENTICATION": case, "ENABLE_BACKEND_ACCESS_CONTROL": "false"}
+            with patch.dict(os.environ, env):
+                # Drop the cached module so the import below re-reads the patched environment
+                module_name = "cognee.modules.users.methods.get_authenticated_user"
+                sys.modules.pop(module_name, None)
+
+                # Import after patching environment
+                from cognee.modules.users.methods.get_authenticated_user import (
+                    REQUIRE_AUTHENTICATION,
+                )
+
+                expected = case.lower() == "true"
+                assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}"
+
+ def test_current_require_authentication_value(self):
+ """Test that the current REQUIRE_AUTHENTICATION module value is as expected."""
+ from cognee.modules.users.methods.get_authenticated_user import (
+ REQUIRE_AUTHENTICATION,
+ )
+
+ # The module-level variable should currently be False (set at import time)
+ assert isinstance(REQUIRE_AUTHENTICATION, bool)
+ assert not REQUIRE_AUTHENTICATION
+
+
+class TestConditionalAuthenticationEdgeCases:
+ """Test edge cases and error scenarios."""
+
+ @pytest.mark.asyncio
+ @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+ async def test_get_default_user_raises_exception(self, mock_get_default):
+ """Test behavior when get_default_user raises an exception."""
+ mock_get_default.side_effect = Exception("Database error")
+
+        # The failure is re-raised as an HTTPException (500) whose detail still contains the original error
+ with pytest.raises(Exception, match="Database error"):
+ await gau_mod.get_authenticated_user(user=None)
+
+ @pytest.mark.asyncio
+ @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+ async def test_user_type_consistency(self, mock_get_default):
+ """Test that the function always returns the same type."""
+ mock_user = User(
+ id=uuid4(),
+ email="test@example.com",
+ hashed_password="hashed",
+ is_active=True,
+ is_verified=True,
+ )
+
+ mock_default_user = SimpleNamespace(id=uuid4(), email="default@example.com", is_active=True)
+ mock_get_default.return_value = mock_default_user
+
+ # Test with user
+ result1 = await gau_mod.get_authenticated_user(user=mock_user)
+ assert result1 == mock_user
+
+ # Test with None
+ result2 = await gau_mod.get_authenticated_user(user=None)
+ assert result2 == mock_default_user
+
+ # Both should have user-like interface
+ assert hasattr(result1, "id")
+ assert hasattr(result1, "email")
+ assert result1.id == mock_user.id
+ assert result1.email == mock_user.email
+ assert hasattr(result2, "id")
+ assert hasattr(result2, "email")
+ assert result2.id == mock_default_user.id
+ assert result2.email == mock_default_user.email
+
+
+@pytest.mark.asyncio
+class TestAuthenticationScenarios:
+ """Test specific authentication scenarios that could occur in FastAPI Users."""
+
+ @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
+ async def test_fallback_to_default_user_scenarios(self, mock_get_default):
+ """
+ Test fallback to default user for all scenarios where FastAPI Users returns None:
+ - No JWT/Cookie present
+ - Invalid JWT/Cookie
+ - Valid JWT but user doesn't exist in database
+ - Valid JWT but user is inactive (active=True requirement)
+
+ All these scenarios result in FastAPI Users returning None when optional=True,
+ which should trigger fallback to default user.
+ """
+ mock_default_user = SimpleNamespace(id=uuid4(), email="default@example.com")
+ mock_get_default.return_value = mock_default_user
+
+ # All the above scenarios result in user=None being passed to our function
+ result = await gau_mod.get_authenticated_user(user=None)
+ assert result == mock_default_user
+ mock_get_default.assert_called_once()
+
+ async def test_scenario_valid_active_user(self):
+ """Scenario: Valid JWT and user exists and is active → returns the user."""
+ mock_user = User(
+ id=uuid4(),
+ email="active@example.com",
+ hashed_password="hashed",
+ is_active=True,
+ is_verified=True,
+ )
+
+ # Use gau_mod.get_authenticated_user instead
+
+ result = await gau_mod.get_authenticated_user(user=mock_user)
+ assert result == mock_user
diff --git a/cognee/tests/unit/processing/utils/utils_test.py b/cognee/tests/unit/processing/utils/utils_test.py
index a684df8ed..ca9f8f065 100644
--- a/cognee/tests/unit/processing/utils/utils_test.py
+++ b/cognee/tests/unit/processing/utils/utils_test.py
@@ -4,8 +4,9 @@ import pytest
from unittest.mock import patch, mock_open
from io import BytesIO
from uuid import uuid4
+from pathlib import Path
-
+from cognee.root_dir import ensure_absolute_path
from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash
from cognee.shared.utils import get_anonymous_id
@@ -52,3 +53,21 @@ async def test_get_file_content_hash_stream():
expected_hash = hashlib.md5(b"test_data").hexdigest()
result = await get_file_content_hash(stream)
assert result == expected_hash
+
+
+@pytest.mark.asyncio
+async def test_root_dir_absolute_paths():
+ """Test absolute path handling in root_dir.py"""
+ # Test with absolute path
+ abs_path = "C:/absolute/path" if os.name == "nt" else "/absolute/path"
+ result = ensure_absolute_path(abs_path)
+ assert result == str(Path(abs_path).resolve())
+
+ # Test with relative path (should fail)
+ rel_path = "relative/path"
+ with pytest.raises(ValueError, match="must be absolute"):
+ ensure_absolute_path(rel_path)
+
+ # Test with None path
+ with pytest.raises(ValueError, match="cannot be None"):
+ ensure_absolute_path(None)