Merge branch 'dev' into update-low-level
This commit is contained in:
commit
ae56b09434
18 changed files with 655 additions and 66 deletions
|
|
@ -124,6 +124,10 @@ ALLOW_HTTP_REQUESTS=True
|
||||||
# When set to False errors during data processing will be returned as info but not raised to allow handling of faulty documents
|
# When set to False errors during data processing will be returned as info but not raised to allow handling of faulty documents
|
||||||
RAISE_INCREMENTAL_LOADING_ERRORS=True
|
RAISE_INCREMENTAL_LOADING_ERRORS=True
|
||||||
|
|
||||||
|
# When set to True, the Cognee backend will require authentication for requests to the API.
|
||||||
|
# If you're disabling this, make sure to also disable ENABLE_BACKEND_ACCESS_CONTROL.
|
||||||
|
REQUIRE_AUTHENTICATION=False
|
||||||
|
|
||||||
# Set this variable to True to enforce usage of backend access control for Cognee
|
# Set this variable to True to enforce usage of backend access control for Cognee
|
||||||
# Note: This is only currently supported by the following databases:
|
# Note: This is only currently supported by the following databases:
|
||||||
# Relational: SQLite, Postgres
|
# Relational: SQLite, Postgres
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,9 @@ More on [use-cases](https://docs.cognee.ai/use-cases) and [evals](https://github
|
||||||
|
|
||||||
## Get Started
|
## Get Started
|
||||||
|
|
||||||
Get started quickly with a Google Colab <a href="https://colab.research.google.com/drive/1jHbWVypDgCLwjE71GSXhRL3YxYhCZzG1?usp=sharing">notebook</a> , <a href="https://deepnote.com/workspace/cognee-382213d0-0444-4c89-8265-13770e333c02/project/cognee-demo-78ffacb9-5832-4611-bb1a-560386068b30/notebook/Notebook-1-75b24cda566d4c24ab348f7150792601?utm_source=share-modal&utm_medium=product-shared-content&utm_campaign=notebook&utm_content=78ffacb9-5832-4611-bb1a-560386068b30">Deepnote notebook</a> or <a href="https://github.com/topoteretes/cognee-starter">starter repo</a>
|
Get started quickly with a Google Colab <a href="https://colab.research.google.com/drive/1jHbWVypDgCLwjE71GSXhRL3YxYhCZzG1?usp=sharing">notebook</a> , <a href="https://deepnote.com/workspace/cognee-382213d0-0444-4c89-8265-13770e333c02/project/cognee-demo-78ffacb9-5832-4611-bb1a-560386068b30/notebook/Notebook-1-75b24cda566d4c24ab348f7150792601?utm_source=share-modal&utm_medium=product-shared-content&utm_campaign=notebook&utm_content=78ffacb9-5832-4611-bb1a-560386068b30">Deepnote notebook</a> or <a href="https://github.com/topoteretes/cognee/tree/main/cognee-starter-kit">starter repo</a>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,7 @@ from cognee.api.v1.users.routers import (
|
||||||
get_users_router,
|
get_users_router,
|
||||||
get_visualize_router,
|
get_visualize_router,
|
||||||
)
|
)
|
||||||
|
from cognee.modules.users.methods.get_authenticated_user import REQUIRE_AUTHENTICATION
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
@ -110,7 +111,11 @@ def custom_openapi():
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
openapi_schema["security"] = [{"BearerAuth": []}, {"CookieAuth": []}]
|
if REQUIRE_AUTHENTICATION:
|
||||||
|
openapi_schema["security"] = [{"BearerAuth": []}, {"CookieAuth": []}]
|
||||||
|
|
||||||
|
# Remove global security requirement - let individual endpoints specify their own security
|
||||||
|
# openapi_schema["security"] = [{"BearerAuth": []}, {"CookieAuth": []}]
|
||||||
|
|
||||||
app.openapi_schema = openapi_schema
|
app.openapi_schema = openapi_schema
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@ class HealthChecker:
|
||||||
# Test connection by creating a session
|
# Test connection by creating a session
|
||||||
session = engine.get_session()
|
session = engine.get_session()
|
||||||
if session:
|
if session:
|
||||||
await session.close()
|
session.close()
|
||||||
|
|
||||||
response_time = int((time.time() - start_time) * 1000)
|
response_time = int((time.time() - start_time) * 1000)
|
||||||
return ComponentHealth(
|
return ComponentHealth(
|
||||||
|
|
@ -190,14 +190,13 @@ class HealthChecker:
|
||||||
"""Check LLM provider health (non-critical)."""
|
"""Check LLM provider health (non-critical)."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
from cognee.infrastructure.llm.LLMGateway import LLMGateway
|
||||||
from cognee.infrastructure.llm.config import get_llm_config
|
from cognee.infrastructure.llm.config import get_llm_config
|
||||||
|
|
||||||
config = get_llm_config()
|
config = get_llm_config()
|
||||||
|
|
||||||
# Test actual API connection with minimal request
|
# Test actual API connection with minimal request
|
||||||
client = get_llm_client()
|
LLMGateway.show_prompt("test", "test")
|
||||||
await client.show_prompt("test", "test")
|
|
||||||
|
|
||||||
response_time = int((time.time() - start_time) * 1000)
|
response_time = int((time.time() - start_time) * 1000)
|
||||||
return ComponentHealth(
|
return ComponentHealth(
|
||||||
|
|
|
||||||
|
|
@ -114,7 +114,8 @@ def get_datasets_router() -> APIRouter:
|
||||||
|
|
||||||
@router.post("", response_model=DatasetDTO)
|
@router.post("", response_model=DatasetDTO)
|
||||||
async def create_new_dataset(
|
async def create_new_dataset(
|
||||||
dataset_data: DatasetCreationPayload, user: User = Depends(get_authenticated_user)
|
dataset_data: DatasetCreationPayload,
|
||||||
|
user: User = Depends(get_authenticated_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new dataset or return existing dataset with the same name.
|
Create a new dataset or return existing dataset with the same name.
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,24 @@
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from cognee.root_dir import get_absolute_path
|
from cognee.root_dir import get_absolute_path, ensure_absolute_path
|
||||||
from cognee.modules.observability.observers import Observer
|
from cognee.modules.observability.observers import Observer
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
|
||||||
class BaseConfig(BaseSettings):
|
class BaseConfig(BaseSettings):
|
||||||
data_root_directory: str = get_absolute_path(".data_storage")
|
data_root_directory: str = get_absolute_path(".data_storage")
|
||||||
system_root_directory: str = get_absolute_path(".cognee_system")
|
system_root_directory: str = get_absolute_path(".cognee_system")
|
||||||
monitoring_tool: object = Observer.LANGFUSE
|
monitoring_tool: object = Observer.LANGFUSE
|
||||||
|
|
||||||
|
@pydantic.model_validator(mode="after")
|
||||||
|
def validate_paths(self):
|
||||||
|
# Require absolute paths for root directories
|
||||||
|
self.data_root_directory = ensure_absolute_path(self.data_root_directory)
|
||||||
|
self.system_root_directory = ensure_absolute_path(self.system_root_directory)
|
||||||
|
return self
|
||||||
|
|
||||||
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
|
langfuse_public_key: Optional[str] = os.getenv("LANGFUSE_PUBLIC_KEY")
|
||||||
langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY")
|
langfuse_secret_key: Optional[str] = os.getenv("LANGFUSE_SECRET_KEY")
|
||||||
langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")
|
langfuse_host: Optional[str] = os.getenv("LANGFUSE_HOST")
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.root_dir import ensure_absolute_path
|
||||||
from cognee.shared.data_models import KnowledgeGraph
|
from cognee.shared.data_models import KnowledgeGraph
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -51,15 +52,20 @@ class GraphConfig(BaseSettings):
|
||||||
@pydantic.model_validator(mode="after")
|
@pydantic.model_validator(mode="after")
|
||||||
def fill_derived(cls, values):
|
def fill_derived(cls, values):
|
||||||
provider = values.graph_database_provider.lower()
|
provider = values.graph_database_provider.lower()
|
||||||
|
base_config = get_base_config()
|
||||||
|
|
||||||
# Set default filename if no filename is provided
|
# Set default filename if no filename is provided
|
||||||
if not values.graph_filename:
|
if not values.graph_filename:
|
||||||
values.graph_filename = f"cognee_graph_{provider}"
|
values.graph_filename = f"cognee_graph_{provider}"
|
||||||
|
|
||||||
# Set file path based on graph database provider if no file path is provided
|
# Handle graph file path
|
||||||
if not values.graph_file_path:
|
if values.graph_file_path:
|
||||||
base_config = get_base_config()
|
# Check if absolute path is provided
|
||||||
|
values.graph_file_path = ensure_absolute_path(
|
||||||
|
os.path.join(values.graph_file_path, values.graph_filename)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Default path
|
||||||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||||
values.graph_file_path = os.path.join(databases_directory_path, values.graph_filename)
|
values.graph_file_path = os.path.join(databases_directory_path, values.graph_filename)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
import os
|
import os
|
||||||
import pydantic
|
import pydantic
|
||||||
|
from pathlib import Path
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
from cognee.base_config import get_base_config
|
from cognee.base_config import get_base_config
|
||||||
|
from cognee.root_dir import ensure_absolute_path
|
||||||
|
|
||||||
|
|
||||||
class VectorConfig(BaseSettings):
|
class VectorConfig(BaseSettings):
|
||||||
|
|
@ -11,11 +13,9 @@ class VectorConfig(BaseSettings):
|
||||||
Manage the configuration settings for the vector database.
|
Manage the configuration settings for the vector database.
|
||||||
|
|
||||||
Public methods:
|
Public methods:
|
||||||
|
|
||||||
- to_dict: Convert the configuration to a dictionary.
|
- to_dict: Convert the configuration to a dictionary.
|
||||||
|
|
||||||
Instance variables:
|
Instance variables:
|
||||||
|
|
||||||
- vector_db_url: The URL of the vector database.
|
- vector_db_url: The URL of the vector database.
|
||||||
- vector_db_port: The port for the vector database.
|
- vector_db_port: The port for the vector database.
|
||||||
- vector_db_key: The key for accessing the vector database.
|
- vector_db_key: The key for accessing the vector database.
|
||||||
|
|
@ -30,10 +30,17 @@ class VectorConfig(BaseSettings):
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
@pydantic.model_validator(mode="after")
|
@pydantic.model_validator(mode="after")
|
||||||
def fill_derived(cls, values):
|
def validate_paths(cls, values):
|
||||||
# Set file path based on graph database provider if no file path is provided
|
base_config = get_base_config()
|
||||||
if not values.vector_db_url:
|
|
||||||
base_config = get_base_config()
|
# If vector_db_url is provided and is not a path skip checking if path is absolute (as it can also be a url)
|
||||||
|
if values.vector_db_url and Path(values.vector_db_url).exists():
|
||||||
|
# Relative path to absolute
|
||||||
|
values.vector_db_url = ensure_absolute_path(
|
||||||
|
values.vector_db_url,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Default path
|
||||||
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
databases_directory_path = os.path.join(base_config.system_root_directory, "databases")
|
||||||
values.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
values.vector_db_url = os.path.join(databases_directory_path, "cognee.lancedb")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from fastembed import TextEmbedding
|
||||||
import litellm
|
import litellm
|
||||||
import os
|
import os
|
||||||
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
||||||
from cognee.infrastructure.llm.tokenizer.TikToken import (
|
from cognee.infrastructure.llm.tokenizer.TikToken import (
|
||||||
TikTokenTokenizer,
|
TikTokenTokenizer,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -250,9 +250,7 @@ def embedding_rate_limit_sync(func):
|
||||||
logger.warning(error_msg)
|
logger.warning(error_msg)
|
||||||
|
|
||||||
# Create a custom embedding rate limit exception
|
# Create a custom embedding rate limit exception
|
||||||
from cognee.infrastructure.databases.exceptions.EmbeddingException import (
|
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
||||||
EmbeddingException,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise EmbeddingException(error_msg)
|
raise EmbeddingException(error_msg)
|
||||||
|
|
||||||
|
|
@ -307,9 +305,7 @@ def embedding_rate_limit_async(func):
|
||||||
logger.warning(error_msg)
|
logger.warning(error_msg)
|
||||||
|
|
||||||
# Create a custom embedding rate limit exception
|
# Create a custom embedding rate limit exception
|
||||||
from cognee.infrastructure.databases.exceptions.EmbeddingException import (
|
from cognee.infrastructure.databases.exceptions import EmbeddingException
|
||||||
EmbeddingException,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise EmbeddingException(error_msg)
|
raise EmbeddingException(error_msg)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,4 +4,7 @@ from .delete_user import delete_user
|
||||||
from .get_default_user import get_default_user
|
from .get_default_user import get_default_user
|
||||||
from .get_user_by_email import get_user_by_email
|
from .get_user_by_email import get_user_by_email
|
||||||
from .create_default_user import create_default_user
|
from .create_default_user import create_default_user
|
||||||
from .get_authenticated_user import get_authenticated_user
|
from .get_authenticated_user import (
|
||||||
|
get_authenticated_user,
|
||||||
|
REQUIRE_AUTHENTICATION,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,48 +1,42 @@
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import Depends, HTTPException
|
||||||
|
from ..models import User
|
||||||
from ..get_fastapi_users import get_fastapi_users
|
from ..get_fastapi_users import get_fastapi_users
|
||||||
|
from .get_default_user import get_default_user
|
||||||
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("get_authenticated_user")
|
||||||
|
|
||||||
|
# Check environment variable to determine authentication requirement
|
||||||
|
REQUIRE_AUTHENTICATION = (
|
||||||
|
os.getenv("REQUIRE_AUTHENTICATION", "false").lower() == "true"
|
||||||
|
or os.getenv("ENABLE_BACKEND_ACCESS_CONTROL", "false").lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
fastapi_users = get_fastapi_users()
|
fastapi_users = get_fastapi_users()
|
||||||
|
|
||||||
get_authenticated_user = fastapi_users.current_user(active=True)
|
_auth_dependency = fastapi_users.current_user(active=True, optional=not REQUIRE_AUTHENTICATION)
|
||||||
|
|
||||||
# from types import SimpleNamespace
|
|
||||||
|
|
||||||
# from ..get_fastapi_users import get_fastapi_users
|
|
||||||
# from fastapi import HTTPException, Security
|
|
||||||
# from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
||||||
# import os
|
|
||||||
# import jwt
|
|
||||||
|
|
||||||
# from uuid import UUID
|
|
||||||
|
|
||||||
# fastapi_users = get_fastapi_users()
|
|
||||||
|
|
||||||
# # Allows Swagger to understand authorization type and allow single sign on for the Swagger docs to test backend
|
|
||||||
# bearer_scheme = HTTPBearer(scheme_name="BearerAuth", description="Paste **Bearer <JWT>**")
|
|
||||||
|
|
||||||
|
|
||||||
# async def get_authenticated_user(
|
async def get_authenticated_user(
|
||||||
# creds: HTTPAuthorizationCredentials = Security(bearer_scheme),
|
user: Optional[User] = Depends(_auth_dependency),
|
||||||
# ) -> SimpleNamespace:
|
) -> User:
|
||||||
# """
|
"""
|
||||||
# Extract and validate the JWT presented in the Authorization header.
|
Get authenticated user with environment-controlled behavior:
|
||||||
# """
|
- If REQUIRE_AUTHENTICATION=true: Enforces authentication (raises 401 if not authenticated)
|
||||||
# if creds is None: # header missing
|
- If REQUIRE_AUTHENTICATION=false: Falls back to default user if not authenticated
|
||||||
# raise HTTPException(status_code=401, detail="Not authenticated")
|
|
||||||
|
|
||||||
# if creds.scheme.lower() != "bearer": # shouldn't happen extra guard
|
Always returns a User object for consistent typing.
|
||||||
# raise HTTPException(status_code=401, detail="Invalid authentication scheme")
|
"""
|
||||||
|
if user is None:
|
||||||
|
# When authentication is optional and user is None, use default user
|
||||||
|
try:
|
||||||
|
user = await get_default_user()
|
||||||
|
except Exception as e:
|
||||||
|
# Convert any get_default_user failure into a proper HTTP 500 error
|
||||||
|
logger.error(f"Failed to create default user: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to create default user: {str(e)}")
|
||||||
|
|
||||||
# token = creds.credentials
|
return user
|
||||||
# try:
|
|
||||||
# payload = jwt.decode(
|
|
||||||
# token, os.getenv("FASTAPI_USERS_JWT_SECRET", "super_secret"), algorithms=["HS256"]
|
|
||||||
# )
|
|
||||||
|
|
||||||
# auth_data = SimpleNamespace(id=UUID(payload["user_id"]))
|
|
||||||
# return auth_data
|
|
||||||
|
|
||||||
# except jwt.ExpiredSignatureError:
|
|
||||||
# raise HTTPException(status_code=401, detail="Token has expired")
|
|
||||||
# except jwt.InvalidTokenError:
|
|
||||||
# raise HTTPException(status_code=401, detail="Invalid token")
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
ROOT_DIR = Path(__file__).resolve().parent
|
ROOT_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
@ -6,3 +7,21 @@ ROOT_DIR = Path(__file__).resolve().parent
|
||||||
def get_absolute_path(path_from_root: str) -> str:
|
def get_absolute_path(path_from_root: str) -> str:
|
||||||
absolute_path = ROOT_DIR / path_from_root
|
absolute_path = ROOT_DIR / path_from_root
|
||||||
return str(absolute_path.resolve())
|
return str(absolute_path.resolve())
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_absolute_path(path: str) -> str:
|
||||||
|
"""Ensures a path is absolute.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: The path to validate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Absolute path as string
|
||||||
|
"""
|
||||||
|
if path is None:
|
||||||
|
raise ValueError("Path cannot be None")
|
||||||
|
path_obj = Path(path).expanduser()
|
||||||
|
if path_obj.is_absolute():
|
||||||
|
return str(path_obj.resolve())
|
||||||
|
|
||||||
|
raise ValueError(f"Path must be absolute. Got relative path: {path}")
|
||||||
|
|
|
||||||
1
cognee/tests/unit/api/__init__.py
Normal file
1
cognee/tests/unit/api/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
# Test package for API tests
|
||||||
|
|
@ -0,0 +1,246 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, AsyncMock, MagicMock
|
||||||
|
from uuid import uuid4
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from types import SimpleNamespace
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from cognee.api.client import app
|
||||||
|
|
||||||
|
|
||||||
|
# Fixtures for reuse across test classes
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_default_user():
|
||||||
|
"""Mock default user for testing."""
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=uuid4(), email="default@example.com", is_active=True, tenant_id=uuid4()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_authenticated_user():
|
||||||
|
"""Mock authenticated user for testing."""
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
|
return User(
|
||||||
|
id=uuid4(),
|
||||||
|
email="auth@example.com",
|
||||||
|
hashed_password="hashed",
|
||||||
|
is_active=True,
|
||||||
|
is_verified=True,
|
||||||
|
tenant_id=uuid4(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
|
||||||
|
|
||||||
|
|
||||||
|
class TestConditionalAuthenticationEndpoints:
|
||||||
|
"""Test that API endpoints work correctly with conditional authentication."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(self):
|
||||||
|
"""Create a test client."""
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
def test_health_endpoint_no_auth_required(self, client):
|
||||||
|
"""Test that health endpoint works without authentication."""
|
||||||
|
response = client.get("/health")
|
||||||
|
assert response.status_code in [200, 503] # 503 is also acceptable for health checks
|
||||||
|
|
||||||
|
def test_root_endpoint_no_auth_required(self, client):
|
||||||
|
"""Test that root endpoint works without authentication."""
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"message": "Hello, World, I am alive!"}
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"cognee.api.client.REQUIRE_AUTHENTICATION",
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
def test_openapi_schema_no_global_security(self, client):
|
||||||
|
"""Test that OpenAPI schema doesn't require global authentication."""
|
||||||
|
response = client.get("/openapi.json")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
schema = response.json()
|
||||||
|
|
||||||
|
# Should not have global security requirement
|
||||||
|
global_security = schema.get("security", [])
|
||||||
|
assert global_security == []
|
||||||
|
|
||||||
|
# But should still have security schemes defined
|
||||||
|
security_schemes = schema.get("components", {}).get("securitySchemes", {})
|
||||||
|
assert "BearerAuth" in security_schemes
|
||||||
|
assert "CookieAuth" in security_schemes
|
||||||
|
|
||||||
|
@patch("cognee.api.v1.add.add")
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
@patch(
|
||||||
|
"cognee.api.client.REQUIRE_AUTHENTICATION",
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
def test_add_endpoint_with_conditional_auth(
|
||||||
|
self, mock_get_default_user, mock_add, client, mock_default_user
|
||||||
|
):
|
||||||
|
"""Test add endpoint works with conditional authentication."""
|
||||||
|
mock_get_default_user.return_value = mock_default_user
|
||||||
|
mock_add.return_value = MagicMock(
|
||||||
|
model_dump=lambda: {"status": "success", "pipeline_run_id": str(uuid4())}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test file upload without authentication
|
||||||
|
files = {"data": ("test.txt", b"test content", "text/plain")}
|
||||||
|
form_data = {"datasetName": "test_dataset"}
|
||||||
|
|
||||||
|
response = client.post("/api/v1/add", files=files, data=form_data)
|
||||||
|
|
||||||
|
assert mock_get_default_user.call_count == 1
|
||||||
|
|
||||||
|
# Core test: authentication is not required (should not get 401)
|
||||||
|
assert response.status_code != 401
|
||||||
|
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
@patch(
|
||||||
|
"cognee.api.client.REQUIRE_AUTHENTICATION",
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
def test_conditional_authentication_works_with_current_environment(
|
||||||
|
self, mock_get_default_user, client
|
||||||
|
):
|
||||||
|
"""Test that conditional authentication works with the current environment setup."""
|
||||||
|
# Since REQUIRE_AUTHENTICATION defaults to "false", we expect endpoints to work without auth
|
||||||
|
# This tests the actual integration behavior
|
||||||
|
|
||||||
|
mock_get_default_user.return_value = SimpleNamespace(
|
||||||
|
id=uuid4(), email="default@example.com", is_active=True, tenant_id=uuid4()
|
||||||
|
)
|
||||||
|
|
||||||
|
files = {"data": ("test.txt", b"test content", "text/plain")}
|
||||||
|
form_data = {"datasetName": "test_dataset"}
|
||||||
|
|
||||||
|
response = client.post("/api/v1/add", files=files, data=form_data)
|
||||||
|
|
||||||
|
assert mock_get_default_user.call_count == 1
|
||||||
|
|
||||||
|
# Core test: authentication is not required (should not get 401)
|
||||||
|
assert response.status_code != 401
|
||||||
|
# Note: This test verifies conditional authentication works in the current environment
|
||||||
|
|
||||||
|
|
||||||
|
class TestConditionalAuthenticationBehavior:
|
||||||
|
"""Test the behavior of conditional authentication across different endpoints."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(self):
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"endpoint,method",
|
||||||
|
[
|
||||||
|
("/api/v1/search", "GET"),
|
||||||
|
("/api/v1/datasets", "GET"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
def test_get_endpoints_work_without_auth(
|
||||||
|
self, mock_get_default, client, endpoint, method, mock_default_user
|
||||||
|
):
|
||||||
|
"""Test that GET endpoints work without authentication (with current environment)."""
|
||||||
|
mock_get_default.return_value = mock_default_user
|
||||||
|
|
||||||
|
if method == "GET":
|
||||||
|
response = client.get(endpoint)
|
||||||
|
elif method == "POST":
|
||||||
|
response = client.post(endpoint, json={})
|
||||||
|
|
||||||
|
assert mock_get_default.call_count == 1
|
||||||
|
|
||||||
|
# Should not return 401 Unauthorized (authentication is optional by default)
|
||||||
|
assert response.status_code != 401
|
||||||
|
|
||||||
|
# May return other errors due to missing data/config, but not auth errors
|
||||||
|
if response.status_code >= 400:
|
||||||
|
# Check that it's not an authentication error
|
||||||
|
try:
|
||||||
|
error_detail = response.json().get("detail", "")
|
||||||
|
assert "authenticate" not in error_detail.lower()
|
||||||
|
assert "unauthorized" not in error_detail.lower()
|
||||||
|
except Exception:
|
||||||
|
pass # If response is not JSON, that's fine
|
||||||
|
|
||||||
|
gsm_mod = importlib.import_module("cognee.modules.settings.get_settings")
|
||||||
|
|
||||||
|
@patch.object(gsm_mod, "get_vectordb_config")
|
||||||
|
@patch.object(gsm_mod, "get_llm_config")
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
def test_settings_endpoint_integration(
|
||||||
|
self, mock_get_default, mock_llm_config, mock_vector_config, client, mock_default_user
|
||||||
|
):
|
||||||
|
"""Test that settings endpoint integration works with conditional authentication."""
|
||||||
|
mock_get_default.return_value = mock_default_user
|
||||||
|
|
||||||
|
# Mock configurations to avoid validation errors
|
||||||
|
mock_llm_config.return_value = SimpleNamespace(
|
||||||
|
llm_provider="openai",
|
||||||
|
llm_model="gpt-4o",
|
||||||
|
llm_endpoint=None,
|
||||||
|
llm_api_version=None,
|
||||||
|
llm_api_key="test_key_1234567890",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_vector_config.return_value = SimpleNamespace(
|
||||||
|
vector_db_provider="lancedb",
|
||||||
|
vector_db_url="localhost:5432", # Must be string, not None
|
||||||
|
vector_db_key="test_vector_key",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/settings")
|
||||||
|
|
||||||
|
assert mock_get_default.call_count == 1
|
||||||
|
|
||||||
|
# Core test: authentication is not required (should not get 401)
|
||||||
|
assert response.status_code != 401
|
||||||
|
# Note: This test verifies conditional authentication works for settings endpoint
|
||||||
|
|
||||||
|
|
||||||
|
class TestConditionalAuthenticationErrorHandling:
|
||||||
|
"""Test error handling in conditional authentication."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(self):
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
def test_get_default_user_fails(self, mock_get_default, client):
|
||||||
|
"""Test behavior when get_default_user fails (with current environment)."""
|
||||||
|
mock_get_default.side_effect = Exception("Database connection failed")
|
||||||
|
|
||||||
|
# The error should propagate - either as a 500 error or as an exception
|
||||||
|
files = {"data": ("test.txt", b"test content", "text/plain")}
|
||||||
|
form_data = {"datasetName": "test_dataset"}
|
||||||
|
|
||||||
|
# Test that the exception is properly converted to HTTP 500
|
||||||
|
response = client.post("/api/v1/add", files=files, data=form_data)
|
||||||
|
|
||||||
|
# Should return HTTP 500 Internal Server Error when get_default_user fails
|
||||||
|
assert response.status_code == 500
|
||||||
|
|
||||||
|
# Check that the error message is informative
|
||||||
|
error_detail = response.json().get("detail", "")
|
||||||
|
assert "Failed to create default user" in error_detail
|
||||||
|
# The exact error message may vary depending on the actual database connection
|
||||||
|
# The important thing is that we get a 500 error when user creation fails
|
||||||
|
|
||||||
|
def test_current_environment_configuration(self):
|
||||||
|
"""Test that current environment configuration is working properly."""
|
||||||
|
# This tests the actual module state without trying to change it
|
||||||
|
from cognee.modules.users.methods.get_authenticated_user import (
|
||||||
|
REQUIRE_AUTHENTICATION,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should be a boolean value (the parsing logic works)
|
||||||
|
assert isinstance(REQUIRE_AUTHENTICATION, bool)
|
||||||
|
|
||||||
|
# In default environment, should be False
|
||||||
|
assert not REQUIRE_AUTHENTICATION
|
||||||
1
cognee/tests/unit/modules/users/__init__.py
Normal file
1
cognee/tests/unit/modules/users/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
# Test package for user module tests
|
||||||
|
|
@ -0,0 +1,277 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
from types import SimpleNamespace
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
gau_mod = importlib.import_module("cognee.modules.users.methods.get_authenticated_user")
|
||||||
|
|
||||||
|
|
||||||
|
class TestConditionalAuthentication:
|
||||||
|
"""Test cases for conditional authentication functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
async def test_require_authentication_false_no_token_returns_default_user(
|
||||||
|
self, mock_get_default
|
||||||
|
):
|
||||||
|
"""Test that when REQUIRE_AUTHENTICATION=false and no token, returns default user."""
|
||||||
|
# Mock the default user
|
||||||
|
mock_default_user = SimpleNamespace(id=uuid4(), email="default@example.com", is_active=True)
|
||||||
|
mock_get_default.return_value = mock_default_user
|
||||||
|
|
||||||
|
# Use gau_mod.get_authenticated_user instead
|
||||||
|
|
||||||
|
# Test with None user (no authentication)
|
||||||
|
result = await gau_mod.get_authenticated_user(user=None)
|
||||||
|
|
||||||
|
assert result == mock_default_user
|
||||||
|
mock_get_default.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
async def test_require_authentication_false_with_valid_user_returns_user(
|
||||||
|
self, mock_get_default
|
||||||
|
):
|
||||||
|
"""Test that when REQUIRE_AUTHENTICATION=false and valid user, returns that user."""
|
||||||
|
mock_authenticated_user = User(
|
||||||
|
id=uuid4(),
|
||||||
|
email="user@example.com",
|
||||||
|
hashed_password="hashed",
|
||||||
|
is_active=True,
|
||||||
|
is_verified=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use gau_mod.get_authenticated_user instead
|
||||||
|
|
||||||
|
# Test with authenticated user
|
||||||
|
result = await gau_mod.get_authenticated_user(user=mock_authenticated_user)
|
||||||
|
|
||||||
|
assert result == mock_authenticated_user
|
||||||
|
mock_get_default.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
async def test_require_authentication_true_with_user_returns_user(self, mock_get_default):
|
||||||
|
"""Test that when REQUIRE_AUTHENTICATION=true and user present, returns user."""
|
||||||
|
mock_authenticated_user = User(
|
||||||
|
id=uuid4(),
|
||||||
|
email="user@example.com",
|
||||||
|
hashed_password="hashed",
|
||||||
|
is_active=True,
|
||||||
|
is_verified=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use gau_mod.get_authenticated_user instead
|
||||||
|
|
||||||
|
result = await gau_mod.get_authenticated_user(user=mock_authenticated_user)
|
||||||
|
|
||||||
|
assert result == mock_authenticated_user
|
||||||
|
|
||||||
|
|
||||||
|
class TestConditionalAuthenticationIntegration:
|
||||||
|
"""Integration tests that test the full authentication flow."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fastapi_users_dependency_creation(self):
|
||||||
|
"""Test that FastAPI Users dependency can be created correctly."""
|
||||||
|
from cognee.modules.users.get_fastapi_users import get_fastapi_users
|
||||||
|
|
||||||
|
fastapi_users = get_fastapi_users()
|
||||||
|
|
||||||
|
# Test that we can create optional dependency
|
||||||
|
optional_dependency = fastapi_users.current_user(optional=True, active=True)
|
||||||
|
assert callable(optional_dependency)
|
||||||
|
|
||||||
|
# Test that we can create required dependency
|
||||||
|
required_dependency = fastapi_users.current_user(active=True) # optional=False by default
|
||||||
|
assert callable(required_dependency)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_conditional_authentication_function_exists(self):
|
||||||
|
"""Test that the conditional authentication function can be imported and used."""
|
||||||
|
from cognee.modules.users.methods.get_authenticated_user import (
|
||||||
|
get_authenticated_user,
|
||||||
|
REQUIRE_AUTHENTICATION,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should be callable
|
||||||
|
assert callable(get_authenticated_user)
|
||||||
|
|
||||||
|
# REQUIRE_AUTHENTICATION should be a boolean
|
||||||
|
assert isinstance(REQUIRE_AUTHENTICATION, bool)
|
||||||
|
|
||||||
|
# Currently should be False (optional authentication)
|
||||||
|
assert not REQUIRE_AUTHENTICATION
|
||||||
|
|
||||||
|
|
||||||
|
class TestConditionalAuthenticationEnvironmentVariables:
|
||||||
|
"""Test environment variable handling."""
|
||||||
|
|
||||||
|
def test_require_authentication_default_false(self):
|
||||||
|
"""Test that REQUIRE_AUTHENTICATION defaults to false when imported with no env vars."""
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
# Remove module from cache to force fresh import
|
||||||
|
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
||||||
|
if module_name in sys.modules:
|
||||||
|
del sys.modules[module_name]
|
||||||
|
|
||||||
|
# Import after patching environment - module will see empty environment
|
||||||
|
from cognee.modules.users.methods.get_authenticated_user import (
|
||||||
|
REQUIRE_AUTHENTICATION,
|
||||||
|
)
|
||||||
|
|
||||||
|
importlib.invalidate_caches()
|
||||||
|
assert not REQUIRE_AUTHENTICATION
|
||||||
|
|
||||||
|
def test_require_authentication_true(self):
|
||||||
|
"""Test that REQUIRE_AUTHENTICATION=true is parsed correctly when imported."""
|
||||||
|
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "true"}):
|
||||||
|
# Remove module from cache to force fresh import
|
||||||
|
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
||||||
|
if module_name in sys.modules:
|
||||||
|
del sys.modules[module_name]
|
||||||
|
|
||||||
|
# Import after patching environment - module will see REQUIRE_AUTHENTICATION=true
|
||||||
|
from cognee.modules.users.methods.get_authenticated_user import (
|
||||||
|
REQUIRE_AUTHENTICATION,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert REQUIRE_AUTHENTICATION
|
||||||
|
|
||||||
|
def test_require_authentication_false_explicit(self):
|
||||||
|
"""Test that REQUIRE_AUTHENTICATION=false is parsed correctly when imported."""
|
||||||
|
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": "false"}):
|
||||||
|
# Remove module from cache to force fresh import
|
||||||
|
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
||||||
|
if module_name in sys.modules:
|
||||||
|
del sys.modules[module_name]
|
||||||
|
|
||||||
|
# Import after patching environment - module will see REQUIRE_AUTHENTICATION=false
|
||||||
|
from cognee.modules.users.methods.get_authenticated_user import (
|
||||||
|
REQUIRE_AUTHENTICATION,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not REQUIRE_AUTHENTICATION
|
||||||
|
|
||||||
|
def test_require_authentication_case_insensitive(self):
|
||||||
|
"""Test that environment variable parsing is case insensitive when imported."""
|
||||||
|
test_cases = ["TRUE", "True", "tRuE", "FALSE", "False", "fAlSe"]
|
||||||
|
|
||||||
|
for case in test_cases:
|
||||||
|
with patch.dict(os.environ, {"REQUIRE_AUTHENTICATION": case}):
|
||||||
|
# Remove module from cache to force fresh import
|
||||||
|
module_name = "cognee.modules.users.methods.get_authenticated_user"
|
||||||
|
if module_name in sys.modules:
|
||||||
|
del sys.modules[module_name]
|
||||||
|
|
||||||
|
# Import after patching environment
|
||||||
|
from cognee.modules.users.methods.get_authenticated_user import (
|
||||||
|
REQUIRE_AUTHENTICATION,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = case.lower() == "true"
|
||||||
|
assert REQUIRE_AUTHENTICATION == expected, f"Failed for case: {case}"
|
||||||
|
|
||||||
|
def test_current_require_authentication_value(self):
|
||||||
|
"""Test that the current REQUIRE_AUTHENTICATION module value is as expected."""
|
||||||
|
from cognee.modules.users.methods.get_authenticated_user import (
|
||||||
|
REQUIRE_AUTHENTICATION,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The module-level variable should currently be False (set at import time)
|
||||||
|
assert isinstance(REQUIRE_AUTHENTICATION, bool)
|
||||||
|
assert not REQUIRE_AUTHENTICATION
|
||||||
|
|
||||||
|
|
||||||
|
class TestConditionalAuthenticationEdgeCases:
|
||||||
|
"""Test edge cases and error scenarios."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
async def test_get_default_user_raises_exception(self, mock_get_default):
|
||||||
|
"""Test behavior when get_default_user raises an exception."""
|
||||||
|
mock_get_default.side_effect = Exception("Database error")
|
||||||
|
|
||||||
|
# This should propagate the exception
|
||||||
|
with pytest.raises(Exception, match="Database error"):
|
||||||
|
await gau_mod.get_authenticated_user(user=None)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
async def test_user_type_consistency(self, mock_get_default):
|
||||||
|
"""Test that the function always returns the same type."""
|
||||||
|
mock_user = User(
|
||||||
|
id=uuid4(),
|
||||||
|
email="test@example.com",
|
||||||
|
hashed_password="hashed",
|
||||||
|
is_active=True,
|
||||||
|
is_verified=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_default_user = SimpleNamespace(id=uuid4(), email="default@example.com", is_active=True)
|
||||||
|
mock_get_default.return_value = mock_default_user
|
||||||
|
|
||||||
|
# Test with user
|
||||||
|
result1 = await gau_mod.get_authenticated_user(user=mock_user)
|
||||||
|
assert result1 == mock_user
|
||||||
|
|
||||||
|
# Test with None
|
||||||
|
result2 = await gau_mod.get_authenticated_user(user=None)
|
||||||
|
assert result2 == mock_default_user
|
||||||
|
|
||||||
|
# Both should have user-like interface
|
||||||
|
assert hasattr(result1, "id")
|
||||||
|
assert hasattr(result1, "email")
|
||||||
|
assert result1.id == mock_user.id
|
||||||
|
assert result1.email == mock_user.email
|
||||||
|
assert hasattr(result2, "id")
|
||||||
|
assert hasattr(result2, "email")
|
||||||
|
assert result2.id == mock_default_user.id
|
||||||
|
assert result2.email == mock_default_user.email
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestAuthenticationScenarios:
|
||||||
|
"""Test specific authentication scenarios that could occur in FastAPI Users."""
|
||||||
|
|
||||||
|
@patch.object(gau_mod, "get_default_user", new_callable=AsyncMock)
|
||||||
|
async def test_fallback_to_default_user_scenarios(self, mock_get_default):
|
||||||
|
"""
|
||||||
|
Test fallback to default user for all scenarios where FastAPI Users returns None:
|
||||||
|
- No JWT/Cookie present
|
||||||
|
- Invalid JWT/Cookie
|
||||||
|
- Valid JWT but user doesn't exist in database
|
||||||
|
- Valid JWT but user is inactive (active=True requirement)
|
||||||
|
|
||||||
|
All these scenarios result in FastAPI Users returning None when optional=True,
|
||||||
|
which should trigger fallback to default user.
|
||||||
|
"""
|
||||||
|
mock_default_user = SimpleNamespace(id=uuid4(), email="default@example.com")
|
||||||
|
mock_get_default.return_value = mock_default_user
|
||||||
|
|
||||||
|
# All the above scenarios result in user=None being passed to our function
|
||||||
|
result = await gau_mod.get_authenticated_user(user=None)
|
||||||
|
assert result == mock_default_user
|
||||||
|
mock_get_default.assert_called_once()
|
||||||
|
|
||||||
|
async def test_scenario_valid_active_user(self):
|
||||||
|
"""Scenario: Valid JWT and user exists and is active → returns the user."""
|
||||||
|
mock_user = User(
|
||||||
|
id=uuid4(),
|
||||||
|
email="active@example.com",
|
||||||
|
hashed_password="hashed",
|
||||||
|
is_active=True,
|
||||||
|
is_verified=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use gau_mod.get_authenticated_user instead
|
||||||
|
|
||||||
|
result = await gau_mod.get_authenticated_user(user=mock_user)
|
||||||
|
assert result == mock_user
|
||||||
|
|
@ -4,8 +4,9 @@ import pytest
|
||||||
from unittest.mock import patch, mock_open
|
from unittest.mock import patch, mock_open
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from cognee.root_dir import ensure_absolute_path
|
||||||
from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash
|
from cognee.infrastructure.files.utils.get_file_content_hash import get_file_content_hash
|
||||||
from cognee.shared.utils import get_anonymous_id
|
from cognee.shared.utils import get_anonymous_id
|
||||||
|
|
||||||
|
|
@ -52,3 +53,21 @@ async def test_get_file_content_hash_stream():
|
||||||
expected_hash = hashlib.md5(b"test_data").hexdigest()
|
expected_hash = hashlib.md5(b"test_data").hexdigest()
|
||||||
result = await get_file_content_hash(stream)
|
result = await get_file_content_hash(stream)
|
||||||
assert result == expected_hash
|
assert result == expected_hash
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_root_dir_absolute_paths():
|
||||||
|
"""Test absolute path handling in root_dir.py"""
|
||||||
|
# Test with absolute path
|
||||||
|
abs_path = "C:/absolute/path" if os.name == "nt" else "/absolute/path"
|
||||||
|
result = ensure_absolute_path(abs_path)
|
||||||
|
assert result == str(Path(abs_path).resolve())
|
||||||
|
|
||||||
|
# Test with relative path (should fail)
|
||||||
|
rel_path = "relative/path"
|
||||||
|
with pytest.raises(ValueError, match="must be absolute"):
|
||||||
|
ensure_absolute_path(rel_path)
|
||||||
|
|
||||||
|
# Test with None path
|
||||||
|
with pytest.raises(ValueError, match="cannot be None"):
|
||||||
|
ensure_absolute_path(None)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue