diff --git a/.env.template b/.env.template index 84dc46d1c..ee62f1d3d 100644 --- a/.env.template +++ b/.env.template @@ -124,6 +124,10 @@ ALLOW_HTTP_REQUESTS=True # When set to False errors during data processing will be returned as info but not raised to allow handling of faulty documents RAISE_INCREMENTAL_LOADING_ERRORS=True +# When set to True, the Cognee backend will require authentication for requests to the API. +# If you're disabling this, make sure to also disable ENABLE_BACKEND_ACCESS_CONTROL. +REQUIRE_AUTHENTICATION=False + # Set this variable to True to enforce usage of backend access control for Cognee # Note: This is only currently supported by the following databases: # Relational: SQLite, Postgres diff --git a/README.md b/README.md index a7e7f1e05..e618d5bf9 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,9 @@ More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github ## Get Started -Get started quickly with a Google Colab notebook , Deepnote notebook or starter repo +Get started quickly with a Google Colab notebook , Deepnote notebook or starter repo + + ## Contributing diff --git a/cognee/api/client.py b/cognee/api/client.py index 215e4a17e..7588638c3 100644 --- a/cognee/api/client.py +++ b/cognee/api/client.py @@ -33,6 +33,7 @@ from cognee.api.v1.users.routers import ( get_users_router, get_visualize_router, ) +from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION logger = get_logger() @@ -110,7 +111,11 @@ def custom_openapi(): }, } - openapi_schema["security"] = [{"BearerAuth": []}, {"CookieAuth": []}] + if REQUIRE_AUTHENTICATION: + openapi_schema["security"] = [{"BearerAuth": []}, {"CookieAuth": []}] + + # Remove global security requirement - let individual endpoints specify their own security + # openapi_schema["security"] = [{"BearerAuth": []}, {"CookieAuth": []}] app.openapi_schema = openapi_schema diff --git a/cognee/api/health.py b/cognee/api/health.py index 0bfbca806..bdb3b1fe3 100644 --- a/cognee/api/health.py +++ b/cognee/api/health.py @@ -53,7 +53,7 @@ class HealthChecker: # Test connection by creating a session session = engine.get_session() if session: - await session.close() + session.close() response_time = int((time.time() - start_time) * 1000) return ComponentHealth( @@ -190,14 +190,13 @@ class HealthChecker: """Check LLM provider health (non-critical).""" start_time = time.time() try: - from cognee.infrastructure.llm.get_llm_client import get_llm_client + from cognee.infrastructure.llm.LLMGateway import LLMGateway from cognee.infrastructure.llm.config import get_llm_config config = get_llm_config() # Test actual API connection with minimal request - client = get_llm_client() - await client.show_prompt("test", "test") + LLMGateway.show_prompt("test", "test") response_time = int((time.time() - start_time) * 1000) return ComponentHealth( diff --git a/cognee/api/v1/datasets/routers/get_datasets_router.py b/cognee/api/v1/datasets/routers/get_datasets_router.py index 8052e3864..ff310e4b4 100644 --- a/cognee/api/v1/datasets/routers/get_datasets_router.py +++ b/cognee/api/v1/datasets/routers/get_datasets_router.py @@ -114,7 +114,8 @@ def get_datasets_router() -> APIRouter: @router.post("", response_model=DatasetDTO) async def create_new_dataset( - dataset_data: DatasetCreationPayload, user: User = Depends(get_authenticated_user) + dataset_data: DatasetCreationPayload, + user: User = Depends(get_authenticated_user), ): """ Create a new dataset or return existing dataset with the same name. diff --git a/cognee/base_config.py b/cognee/base_config.py index aa0b14008..940846128 100644 --- a/cognee/base_config.py +++ b/cognee/base_config.py @@ -1,15 +1,24 @@ import os from typing import Optional from functools import lru_cache -from cognee.root_dir import get_absolute_path +from cognee.root_dir import get_absolute_path, ensure_absolute_path from cognee.modules.observability.observers import Observer from pydantic_settings import BaseSettings, SettingsConfigDict +import pydantic class BaseConfig(BaseSettings): data_root_directory: str = get_absolute_path(".data_storage") system_root_directory: str = get_absolute_path(".cognee_system") monitoring_tool: object = Observer.LANGFUSE + + @pydantic.model_validator(mode="after") + def validate_paths(self): + # Require absolute paths for root directories + self.data_root_directory = ensure_absolute_path(self.data_root_directory) + self.system_root_directory = ensure_absolute_path(self.system_root_directory) + return self + langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY") langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY") langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST") diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py index cdc001863..d96de4520 100644 --- a/cognee/infrastructure/databases/graph/config.py +++ b/cognee/infrastructure/databases/graph/config.py @@ -6,6 +6,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict import pydantic from pydantic import Field from cognee.base_config import get_base_config +from cognee.root_dir import ensure_absolute_path from cognee.shared.data_models import KnowledgeGraph @@ -51,15 +52,20 @@ class GraphConfig(BaseSettings): @pydantic.model_validator(mode="after") def fill_derived(cls, values): provider = values.graph_database_provider.lower() + base_config = get_base_config() # Set default filename if no filename is provided if not values.graph_filename: values.graph_filename = f"cognee_graph_{provider}" - # Set file path based on graph database provider if no file path is provided - if not values.graph_file_path: - base_config = get_base_config() - + # Handle graph file path + if values.graph_file_path: + # Check if absolute path is provided + values.graph_file_path = ensure_absolute_path( + os.path.join(values.graph_file_path, values.graph_filename) + ) + else: + # Default path databases_directory_path = os.path.join(base_config.system_root_directory, "databases") values.graph_file_path = os.path.join(databases_directory_path, values.graph_filename) diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py index 07a3d1e05..f8fad473e 100644 --- a/cognee/infrastructure/databases/vector/config.py +++ b/cognee/infrastructure/databases/vector/config.py @@ -1,9 +1,11 @@ import os import pydantic +from pathlib import Path from functools import lru_cache from pydantic_settings import BaseSettings, SettingsConfigDict from cognee.base_config import get_base_config +from cognee.root_dir import ensure_absolute_path class VectorConfig(BaseSettings): @@ -11,11 +13,9 @@ class VectorConfig(BaseSettings): Manage the configuration settings for the vector database. Public methods: - - to_dict: Convert the configuration to a dictionary. Instance variables: - - vector_db_url: The URL of the vector database. - vector_db_port: The port for the vector database. - vector_db_key: The key for accessing the vector database. @@ -30,10 +30,17 @@ class VectorConfig(BaseSettings): model_config = SettingsConfigDict(env_file=".env", extra="allow") @pydantic.model_validator(mode="after") - def fill_derived(cls, values): - # Set file path based on graph database provider if no file path is provided - if not values.vector_db_url: - base_config = get_base_config() + def validate_paths(cls, values): + base_config = get_base_config() + + # If vector_db_url is provided and is not a path skip checking if path is absolute (as it can also be a url) + if values.vector_db_url and Path(values.vector_db_url).exists(): + # Relative path to absolute + values.vector_db_url = ensure_absolute_path( + values.vector_db_url, + ) + else: + # Default path databases_directory_path = os.path.join(base_config.system_root_directory, "databases") values.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb") diff --git a/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py index dc8443459..acb041e76 100644 --- a/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/FastembedEmbeddingEngine.py @@ -4,7 +4,7 @@ from fastembed import TextEmbedding import litellm import os from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine -from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException +from cognee.infrastructure.databases.exceptions import EmbeddingException from cognee.infrastructure.llm.tokenizer.TikToken import ( TikTokenTokenizer, ) diff --git a/cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py b/cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py index 24312dab1..27688d2c9 100644 --- a/cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py +++ b/cognee/infrastructure/databases/vector/embeddings/embedding_rate_limiter.py @@ -250,9 +250,7 @@ def embedding_rate_limit_sync(func): logger.warning(error_msg) # Create a custom embedding rate limit exception - from cognee.infrastructure.databases.exceptions.EmbeddingException import ( - EmbeddingException, - ) + from cognee.infrastructure.databases.exceptions import EmbeddingException raise EmbeddingException(error_msg) @@ -307,9 +305,7 @@ def embedding_rate_limit_async(func): logger.warning(error_msg) # Create a custom embedding rate limit exception - from cognee.infrastructure.databases.exceptions.EmbeddingException import ( - EmbeddingException, - ) + from cognee.infrastructure.databases.exceptions import EmbeddingException raise EmbeddingException(error_msg) diff --git a/cognee/modules/users/methods/__init__.py b/cognee/modules/users/methods/__init__.py index 969615b89..5d45df97b 100644 --- a/cognee/modules/users/methods/__init__.py +++ b/cognee/modules/users/methods/__init__.py @@ -4,4 +4,7 @@ from .delete_user import delete_user from .get_default_user import get_default_user from .get_user_by_email import get_user_by_email from .create_default_user import create_default_user -from .get_authenticated_user import get_authenticated_user +from .get_authenticated_user import ( + get_authenticated_user, + REQUIRE_AUTHENTICATION, +) diff --git a/cognee/modules/users/methods/get_authenticated_user.py b/cognee/modules/users/methods/get_authenticated_user.py index b60ddfe28..a2dd2330e 100644 --- a/cognee/modules/users/methods/get_authenticated_user.py +++ b/cognee/modules/users/methods/get_authenticated_user.py @@ -1,48 +1,42 @@ +import os +from typing import Optional +from fastapi import Depends, HTTPException +from ..models import User from ..get_fastapi_users import get_fastapi_users +from .get_default_user import get_default_user +from cognee.shared.logging_utils import get_logger +logger = get_logger("get_authenticated_user") + +# Check environment variable to determine authentication requirement +REQUIRE_AUTHENTICATION = ( + os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true" + or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true" +) + fastapi_users = get_fastapi_users() -get_authenticated_user = fastapi_users.current_user(active=True) - -# from types import SimpleNamespace - -# from ..get_fastapi_users import get_fastapi_users -# from fastapi import HTTPException, Security -# from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -# import os -# import jwt - -# from uuid import UUID - -# fastapi_users = get_fastapi_users() - -# # Allows Swagger to understand authorization type and allow single sign on for the Swagger docs to test backend -# bearer_scheme = HTTPBearer(scheme_name="BearerAuth", description="Paste **Bearer <JWT>**") +_auth_dependency = fastapi_users.current_user(active=True, optional=not REQUIRE_AUTHENTICATION) -# async def get_authenticated_user( -# creds: HTTPAuthorizationCredentials = Security(bearer_scheme), -# ) -> SimpleNamespace: -# """ -# Extract and validate the JWT presented in the Authorization header. -# """ -# if creds is None: # header missing -# raise HTTPException(status_code=401, detail="Not authenticated") +async def get_authenticated_user( + user: Optional[User] = Depends(_auth_dependency), +) -> User: + """ + Get authenticated user with environment-controlled behavior: + - If REQUIRE_AUTHENTICATION=true: Enforces authentication (raises 401 if not authenticated) + - If REQUIRE_AUTHENTICATION=false: Falls back to default user if not authenticated -# if creds.scheme.lower() != "bearer": # shouldn't happen extra guard -# raise HTTPException(status_code=401, detail="Invalid authentication scheme") + Always returns a User object for consistent typing. + """ + if user is None: + # When authentication is optional and user is None, use default user + try: + user = await get_default_user() + except Exception as e: + # Convert any get_default_user failure into a proper HTTP 500 error + logger.error(f"Failed to create default user: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to create default user: {str(e)}") -# token = creds.credentials -# try: -# payload = jwt.decode( -# token, os.getenv("FASTAPI_USERS_JWT_SECRET", "super_secret"), algorithms=["HS256"] -# ) - -# auth_data = SimpleNamespace(id=UUID(payload["user_id"])) -# return auth_data - -# except jwt.ExpiredSignatureError: -# raise HTTPException(status_code=401, detail="Token has expired") -# except jwt.InvalidTokenError: -# raise HTTPException(status_code=401, detail="Invalid token") + return user diff --git a/cognee/root_dir.py b/cognee/root_dir.py index 2e21d5ce3..46d8fcb69 100644 --- a/cognee/root_dir.py +++ b/cognee/root_dir.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Optional ROOT_DIR = Path(__file__).resolve().parent @@ -6,3 +7,21 @@ ROOT_DIR = Path(__file__).resolve().parent def get_absolute_path(path_from_root: str) -> str: absolute_path = ROOT_DIR / path_from_root return str(absolute_path.resolve()) + + +def ensure_absolute_path(path: str) -> str: + """Ensures a path is absolute. + + Args: + path: The path to validate. + + Returns: + Absolute path as string + """ + if path is None: + raise ValueError("Path cannot be None") + path_obj = Path(path).expanduser() + if path_obj.is_absolute(): + return str(path_obj.resolve()) + + raise ValueError(f"Path must be absolute. Got relative path: {path}") diff --git a/cognee/tests/unit/api/__init__.py b/cognee/tests/unit/api/__init__.py new file mode 100644 index 000000000..2b1755712 --- /dev/null +++ b/cognee/tests/unit/api/__init__.py @@ -0,0 +1 @@ +# Test package for API tests diff --git a/cognee/tests/unit/api/test_conditional_authentication_endpoints.py b/cognee/tests/unit/api/test_conditional_authentication_endpoints.py new file mode 100644 index 000000000..2eabee91a --- /dev/null +++ b/cognee/tests/unit/api/test_conditional_authentication_endpoints.py @@ -0,0 +1,246 @@ +import pytest +from unittest.mock import patch, AsyncMock, MagicMock +from uuid import uuid4 +from fastapi.testclient import TestClient +from types import SimpleNamespace +import importlib + +from cognee.api.client import app + + +# Fixtures for reuse across test classes +@pytest.fixture +def mock_default_user(): + """Mock default user for testing.""" + return SimpleNamespace( + id=uuid4(), email="default@example.com", is_active=True, tenant_id=uuid4() + ) + + +@pytest.fixture +def mock_authenticated_user(): + """Mock authenticated user for testing.""" + from cognee.modules.users.models import User + + return User( + id=uuid4(), + email="auth@example.com", + hashed_password="hashed", + is_active=True, + is_verified=True, + tenant_id=uuid4(), + ) + + +gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") + + +class TestConditionalAuthenticationEndpoints: + """Test that API endpoints work correctly with conditional authentication.""" + + @pytest.fixture + def client(self): + """Create a test client.""" + return TestClient(app) + + def test_health_endpoint_no_auth_required(self, client): + """Test that health endpoint works without authentication.""" + response = client.get("/health") + assert response.status_code in [200, 503] # 503 is also acceptable for health checks + + def test_root_endpoint_no_auth_required(self, client): + """Test that root endpoint works without authentication.""" + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"message": "Hello, World, I am alive!"} + + @patch( + "cognee.api.client.REQUIRE_AUTHENTICATION", + False, + ) + def test_openapi_schema_no_global_security(self, client): + """Test that OpenAPI schema doesn't require global authentication.""" + response = client.get("/openapi.json") + assert response.status_code == 200 + + schema = response.json() + + # Should not have global security requirement + global_security = schema.get("security", []) + assert global_security == [] + + # But should still have security schemes defined + security_schemes = schema.get("components", {}).get("securitySchemes", {}) + assert "BearerAuth" in security_schemes + assert "CookieAuth" in security_schemes + + @patch("cognee.api.v1.add.add") + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + @patch( + "cognee.api.client.REQUIRE_AUTHENTICATION", + False, + ) + def test_add_endpoint_with_conditional_auth( + self, mock_get_default_user, mock_add, client, mock_default_user + ): + """Test add endpoint works with conditional authentication.""" + mock_get_default_user.return_value = mock_default_user + mock_add.return_value = MagicMock( + model_dump=lambda: {"status": "success", "pipeline_run_id": str(uuid4())} + ) + + # Test file upload without authentication + files = {"data": ("test.txt", b"test content", "text/plain")} + form_data = {"datasetName": "test_dataset"} + + response = client.post("/api/v1/add", files=files, data=form_data) + + assert mock_get_default_user.call_count == 1 + + # Core test: authentication is not required (should not get 401) + assert response.status_code != 401 + + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + @patch( + "cognee.api.client.REQUIRE_AUTHENTICATION", + False, + ) + def test_conditional_authentication_works_with_current_environment( + self, mock_get_default_user, client + ): + """Test that conditional authentication works with the current environment setup.""" + # Since REQUIRE_AUTHENTICATION defaults to "false", we expect endpoints to work without auth + # This tests the actual integration behavior + + mock_get_default_user.return_value = SimpleNamespace( + id=uuid4(), email="default@example.com", is_active=True, tenant_id=uuid4() + ) + + files = {"data": ("test.txt", b"test content", "text/plain")} + form_data = {"datasetName": "test_dataset"} + + response = client.post("/api/v1/add", files=files, data=form_data) + + assert mock_get_default_user.call_count == 1 + + # Core test: authentication is not required (should not get 401) + assert response.status_code != 401 + # Note: This test verifies conditional authentication works in the current environment + + +class TestConditionalAuthenticationBehavior: + """Test the behavior of conditional authentication across different endpoints.""" + + @pytest.fixture + def client(self): + return TestClient(app) + + @pytest.mark.parametrize( + "endpoint,method", + [ + ("/api/v1/search", "GET"), + ("/api/v1/datasets", "GET"), + ], + ) + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + def test_get_endpoints_work_without_auth( + self, mock_get_default, client, endpoint, method, mock_default_user + ): + """Test that GET endpoints work without authentication (with current environment).""" + mock_get_default.return_value = mock_default_user + + if method == "GET": + response = client.get(endpoint) + elif method == "POST": + response = client.post(endpoint, json={}) + + assert mock_get_default.call_count == 1 + + # Should not return 401 Unauthorized (authentication is optional by default) + assert response.status_code != 401 + + # May return other errors due to missing data/config, but not auth errors + if response.status_code >= 400: + # Check that it's not an authentication error + try: + error_detail = response.json().get("detail", "") + assert "authenticate" not in error_detail.lower() + assert "unauthorized" not in error_detail.lower() + except Exception: + pass # If response is not JSON, that's fine + + gsm_mod = importlib.import_module("cognee.modules.settings.get_settings") + + @patch.object(gsm_mod, "get_vectordb_config") + @patch.object(gsm_mod, "get_llm_config") + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + def test_settings_endpoint_integration( + self, mock_get_default, mock_llm_config, mock_vector_config, client, mock_default_user + ): + """Test that settings endpoint integration works with conditional authentication.""" + mock_get_default.return_value = mock_default_user + + # Mock configurations to avoid validation errors + mock_llm_config.return_value = SimpleNamespace( + llm_provider="openai", + llm_model="gpt-4o", + llm_endpoint=None, + llm_api_version=None, + llm_api_key="test_key_1234567890", + ) + + mock_vector_config.return_value = SimpleNamespace( + vector_db_provider="lancedb", + vector_db_url="localhost:5432", # Must be string, not None + vector_db_key="test_vector_key", + ) + + response = client.get("/api/v1/settings") + + assert mock_get_default.call_count == 1 + + # Core test: authentication is not required (should not get 401) + assert response.status_code != 401 + # Note: This test verifies conditional authentication works for settings endpoint + + +class TestConditionalAuthenticationErrorHandling: + """Test error handling in conditional authentication.""" + + @pytest.fixture + def client(self): + return TestClient(app) + + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + def test_get_default_user_fails(self, mock_get_default, client): + """Test behavior when get_default_user fails (with current environment).""" + mock_get_default.side_effect = Exception("Database connection failed") + + # The error should propagate - either as a 500 error or as an exception + files = {"data": ("test.txt", b"test content", "text/plain")} + form_data = {"datasetName": "test_dataset"} + + # Test that the exception is properly converted to HTTP 500 + response = client.post("/api/v1/add", files=files, data=form_data) + + # Should return HTTP 500 Internal Server Error when get_default_user fails + assert response.status_code == 500 + + # Check that the error message is informative + error_detail = response.json().get("detail", "") + assert "Failed to create default user" in error_detail + # The exact error message may vary depending on the actual database connection + # The important thing is that we get a 500 error when user creation fails + + def test_current_environment_configuration(self): + """Test that current environment configuration is working properly.""" + # This tests the actual module state without trying to change it + from cognee.modules.users.methods.get_authenticated_user import ( + REQUIRE_AUTHENTICATION, + ) + + # Should be a boolean value (the parsing logic works) + assert isinstance(REQUIRE_AUTHENTICATION, bool) + + # In default environment, should be False + assert not REQUIRE_AUTHENTICATION diff --git a/cognee/tests/unit/modules/users/__init__.py b/cognee/tests/unit/modules/users/__init__.py new file mode 100644 index 000000000..a5e9995d3 --- /dev/null +++ b/cognee/tests/unit/modules/users/__init__.py @@ -0,0 +1 @@ +# Test package for user module tests diff --git a/cognee/tests/unit/modules/users/test_conditional_authentication.py b/cognee/tests/unit/modules/users/test_conditional_authentication.py new file mode 100644 index 000000000..c4368d796 --- /dev/null +++ b/cognee/tests/unit/modules/users/test_conditional_authentication.py @@ -0,0 +1,277 @@ +import os +import sys +import pytest +from unittest.mock import AsyncMock, patch +from uuid import uuid4 +from types import SimpleNamespace +import importlib + + +from cognee.modules.users.models import User + + +gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user") + + +class TestConditionalAuthentication: + """Test cases for conditional authentication functionality.""" + + @pytest.mark.asyncio + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + async def test_require_authentication_false_no_token_returns_default_user( + self, mock_get_default + ): + """Test that when REQUIRE_AUTHENTICATION=false and no token, returns default user.""" + # Mock the default user + mock_default_user = SimpleNamespace(id=uuid4(), email="default@example.com", is_active=True) + mock_get_default.return_value = mock_default_user + + # Use gau_mod.get_authenticated_user instead + + # Test with None user (no authentication) + result = await gau_mod.get_authenticated_user(user=None) + + assert result == mock_default_user + mock_get_default.assert_called_once() + + @pytest.mark.asyncio + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + async def test_require_authentication_false_with_valid_user_returns_user( + self, mock_get_default + ): + """Test that when REQUIRE_AUTHENTICATION=false and valid user, returns that user.""" + mock_authenticated_user = User( + id=uuid4(), + email="user@example.com", + hashed_password="hashed", + is_active=True, + is_verified=True, + ) + + # Use gau_mod.get_authenticated_user instead + + # Test with authenticated user + result = await gau_mod.get_authenticated_user(user=mock_authenticated_user) + + assert result == mock_authenticated_user + mock_get_default.assert_not_called() + + @pytest.mark.asyncio + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + async def test_require_authentication_true_with_user_returns_user(self, mock_get_default): + """Test that when REQUIRE_AUTHENTICATION=true and user present, returns user.""" + mock_authenticated_user = User( + id=uuid4(), + email="user@example.com", + hashed_password="hashed", + is_active=True, + is_verified=True, + ) + + # Use gau_mod.get_authenticated_user instead + + result = await gau_mod.get_authenticated_user(user=mock_authenticated_user) + + assert result == mock_authenticated_user + + +class TestConditionalAuthenticationIntegration: + """Integration tests that test the full authentication flow.""" + + @pytest.mark.asyncio + async def test_fastapi_users_dependency_creation(self): + """Test that FastAPI Users dependency can be created correctly.""" + from cognee.modules.users.get_fastapi_users import get_fastapi_users + + fastapi_users = get_fastapi_users() + + # Test that we can create optional dependency + optional_dependency = fastapi_users.current_user(optional=True, active=True) + assert callable(optional_dependency) + + # Test that we can create required dependency + required_dependency = fastapi_users.current_user(active=True) # optional=False by default + assert callable(required_dependency) + + @pytest.mark.asyncio + async def test_conditional_authentication_function_exists(self): + """Test that the conditional authentication function can be imported and used.""" + from cognee.modules.users.methods.get_authenticated_user import ( + get_authenticated_user, + REQUIRE_AUTHENTICATION, + ) + + # Should be callable + assert callable(get_authenticated_user) + + # REQUIRE_AUTHENTICATION should be a boolean + assert isinstance(REQUIRE_AUTHENTICATION, bool) + + # Currently should be False (optional authentication) + assert not REQUIRE_AUTHENTICATION + + +class TestConditionalAuthenticationEnvironmentVariables: + """Test environment variable handling.""" + + def test_require_authentication_default_false(self): + """Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars.""" + with patch.dict(os.environ, {}, clear=True): + # Remove module from cache to force fresh import + module_name = "cognee.modules.users.methods.get_authenticated_user" + if module_name in sys.modules: + del sys.modules[module_name] + + # Import after patching environment - module will see empty environment + from cognee.modules.users.methods.get_authenticated_user import ( + REQUIRE_AUTHENTICATION, + ) + + importlib.invalidate_caches() + assert not REQUIRE_AUTHENTICATION + + def test_require_authentication_true(self): + """Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported.""" + with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}): + # Remove module from cache to force fresh import + module_name = "cognee.modules.users.methods.get_authenticated_user" + if module_name in sys.modules: + del sys.modules[module_name] + + # Import after patching environment - module will see REQUIRE_AUTHENTICATION=true + from cognee.modules.users.methods.get_authenticated_user import ( + REQUIRE_AUTHENTICATION, + ) + + assert REQUIRE_AUTHENTICATION + + def test_require_authentication_false_explicit(self): + """Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported.""" + with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}): + # Remove module from cache to force fresh import + module_name = "cognee.modules.users.methods.get_authenticated_user" + if module_name in sys.modules: + del sys.modules[module_name] + + # Import after patching environment - module will see REQUIRE_AUTHENTICATION=false + from cognee.modules.users.methods.get_authenticated_user import ( + REQUIRE_AUTHENTICATION, + ) + + assert not REQUIRE_AUTHENTICATION + + def test_require_authentication_case_insensitive(self): + """Test that environment variable parsing is case insensitive when imported.""" + test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"] + + for case in test_cases: + with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}): + # Remove module from cache to force fresh import + module_name = "cognee.modules.users.methods.get_authenticated_user" + if module_name in sys.modules: + del sys.modules[module_name] + + # Import after patching environment + from cognee.modules.users.methods.get_authenticated_user import ( + REQUIRE_AUTHENTICATION, + ) + + expected = case.lower() == "true" + assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}" + + def test_current_require_authentication_value(self): + """Test that the current REQUIRE_AUTHENTICATION module value is as expected.""" + from cognee.modules.users.methods.get_authenticated_user import ( + REQUIRE_AUTHENTICATION, + ) + + # The module-level variable should currently be False (set at import time) + assert isinstance(REQUIRE_AUTHENTICATION, bool) + assert not REQUIRE_AUTHENTICATION + + +class TestConditionalAuthenticationEdgeCases: + """Test edge cases and error scenarios.""" + + @pytest.mark.asyncio + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + async def test_get_default_user_raises_exception(self, mock_get_default): + """Test behavior when get_default_user raises an exception.""" + mock_get_default.side_effect = Exception("Database error") + + # This should propagate the exception + with pytest.raises(Exception, match="Database error"): + await gau_mod.get_authenticated_user(user=None) + + @pytest.mark.asyncio + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + async def test_user_type_consistency(self, mock_get_default): + """Test that the function always returns the same type.""" + mock_user = User( + id=uuid4(), + email="test@example.com", + hashed_password="hashed", + is_active=True, + is_verified=True, + ) + + mock_default_user = SimpleNamespace(id=uuid4(), email="default@example.com", is_active=True) + mock_get_default.return_value = mock_default_user + + # Test with user + result1 = await gau_mod.get_authenticated_user(user=mock_user) + assert result1 == mock_user + + # Test with None + result2 = await gau_mod.get_authenticated_user(user=None) + assert result2 == mock_default_user + + # Both should have user-like interface + assert hasattr(result1, "id") + assert hasattr(result1, "email") + assert result1.id == mock_user.id + assert result1.email == mock_user.email + assert hasattr(result2, "id") + assert hasattr(result2, "email") + assert result2.id == mock_default_user.id + assert result2.email == mock_default_user.email + + +@pytest.mark.asyncio +class TestAuthenticationScenarios: + """Test specific authentication scenarios that could occur in FastAPI Users.""" + + @patch.object(gau_mod, "get_default_user", new_callable=AsyncMock) + async def test_fallback_to_default_user_scenarios(self, mock_get_default): + """ + Test fallback to default user for all scenarios where FastAPI Users returns None: + - No JWT/Cookie present + - Invalid JWT/Cookie + - Valid JWT but user doesn't exist in database + - Valid JWT but user is inactive (active=True requirement) + + All these scenarios result in FastAPI Users returning None when optional=True, + which should trigger fallback to default user. + """ + mock_default_user = SimpleNamespace(id=uuid4(), email="default@example.com") + mock_get_default.return_value = mock_default_user + + # All the above scenarios result in user=None being passed to our function + result = await gau_mod.get_authenticated_user(user=None) + assert result == mock_default_user + mock_get_default.assert_called_once() + + async def test_scenario_valid_active_user(self): + """Scenario: Valid JWT and user exists and is active → returns the user.""" + mock_user = User( + id=uuid4(), + email="active@example.com", + hashed_password="hashed", + is_active=True, + is_verified=True, + ) + + # Use gau_mod.get_authenticated_user instead + + result = await gau_mod.get_authenticated_user(user=mock_user) + assert result == mock_user diff --git a/cognee/tests/unit/processing/utils/utils_test.py b/cognee/tests/unit/processing/utils/utils_test.py index a684df8ed..ca9f8f065 100644 --- a/cognee/tests/unit/processing/utils/utils_test.py +++ b/cognee/tests/unit/processing/utils/utils_test.py @@ -4,8 +4,9 @@ import pytest from unittest.mock import patch, mock_open from io import BytesIO from uuid import uuid4 +from pathlib import Path - +from cognee.root_dir import ensure_absolute_path from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash from cognee.shared.utils import get_anonymous_id @@ -52,3 +53,21 @@ async def test_get_file_content_hash_stream(): expected_hash = hashlib.md5(b"test_data").hexdigest() result = await get_file_content_hash(stream) assert result == expected_hash + + +@pytest.mark.asyncio +async def test_root_dir_absolute_paths(): + """Test absolute path handling in root_dir.py""" + # Test with absolute path + abs_path = "C:/absolute/path" if os.name == "nt" else "/absolute/path" + result = ensure_absolute_path(abs_path) + assert result == str(Path(abs_path).resolve()) + + # Test with relative path (should fail) + rel_path = "relative/path" + with pytest.raises(ValueError, match="must be absolute"): + ensure_absolute_path(rel_path) + + # Test with None path + with pytest.raises(ValueError, match="cannot be None"): + ensure_absolute_path(None)