feat: Add API and tool logging to Redis (#2000)

<!-- .github/pull_request_template.md -->

## Description

This PR implements a usage-logging system that:

- Logs API endpoint and MCP tool calls to Redis
- Captures the function/tool name, type (api_endpoint/mcp_tool),
parameters, return value, user information, timing, and metadata
- Stores logs in Redis with a configurable TTL (default: 7 days)
- Is transparent to existing code (decorator pattern; see the sketch
below)
- Handles all parameter types gracefully (including non-serializable
values)
- Works independently of the caching system
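
For illustration, the decorator wraps an endpoint or tool without
touching its body (the signature below is a simplified stand-in; only
`log_usage` itself comes from this PR):

```python
from cognee.shared.usage_logger import log_usage

# The wrapper looks for a `user` argument with an .id attribute to
# attribute the log entry; everything else is passed through untouched.
@log_usage(function_name="POST /v1/add", log_type="api_endpoint")
async def add(data: list, user=None) -> dict:
    return {"received": len(data)}
```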

## Acceptance Criteria

Deploy Redis locally and set `USAGE_LOGGING` to `true` and
`CACHE_BACKEND` to `redis`.

Calls to the MCP tools (cognify, save_interaction, search, list_data,
delete, prune, cognify_status) and to the default Cognee endpoints (add,
cognify, search, memify) should then log entries with the following
fields to Redis (an illustrative entry follows the list):

- timestamp
- type
- function_name
- user_id
- parameters
- result
- success
- error
- duration_ms
- metadata
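
An illustrative entry (all values are hypothetical; the shape follows
the field list above):

```json
{
  "timestamp": "2026-01-19T17:05:01+00:00",
  "type": "api_endpoint",
  "function_name": "POST /v1/add",
  "user_id": "123e4567-e89b-12d3-a456-426614174000",
  "parameters": {"datasetName": "demo"},
  "result": {"status": "ok"},
  "success": true,
  "error": null,
  "duration_ms": 12.34,
  "metadata": {"cognee_version": "0.0.0", "environment": "dev"}
}
```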


## Type of Change
<!-- Please check the relevant option -->
- [ ] Bug fix (non-breaking change that fixes an issue)
- [x] New feature (non-breaking change that adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Documentation update
- [ ] Code refactoring
- [ ] Performance improvement
- [ ] Other (please specify):

## Screenshots/Videos (if applicable)
<!-- Add screenshots or videos to help explain your changes -->

## Pre-submission Checklist
<!-- Please check all boxes that apply before submitting your PR -->
- [x] **I have tested my changes thoroughly before submitting this PR**
- [x] **This PR contains minimal changes necessary to address the
issue/feature**
- [x] My code follows the project's coding standards and style
guidelines
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added necessary documentation (if applicable)
- [x] All new and existing tests pass
- [x] I have searched existing PRs to ensure this change hasn't been
submitted already
- [x] I have linked any relevant issues in the description
- [x] My commits have clear and descriptive messages

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
  * Usage logging added across API endpoints and MCP tools to record
per-user/tool activity.
  * Configuration options to enable logging and set retention (default 7
days).
  * Retrieve recent usage logs with configurable limits and
most-recent-first ordering.

* **Tests**
  * New unit, integration, and end-to-end tests validating logging,
storage, retrieval, TTL, and edge cases.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Commit 6a01ad3a00 by Vasilije, 2026-01-19 17:05:01 +00:00.
18 changed files with 1308 additions and 9 deletions


@@ -659,3 +659,51 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_pipeline_cache.py
run_usage_logger_test:
name: Usage logger test (API/MCP)
runs-on: ubuntu-latest
defaults:
run:
shell: bash
services:
redis:
image: redis:7
ports:
- 6379:6379
options: >-
--health-cmd "redis-cli ping"
--health-interval 5s
--health-timeout 3s
--health-retries 5
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "redis"
- name: Install cognee-mcp (local version)
shell: bash
run: |
uv pip install -e ./cognee-mcp
- name: Run api/tool usage logger
env:
ENV: dev
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
USAGE_LOGGING: true
CACHE_BACKEND: 'redis'
run: uv run pytest cognee/tests/test_usage_logger_e2e.py -v --log-level=INFO
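
For a quick manual check against the same Redis service, the stored
list can be inspected directly with redis-cli (the user ID below is a
placeholder):

```bash
redis-cli LRANGE "usage_logs:<user_id>" -5 -1   # five most recent entries (raw JSON)
redis-cli TTL "usage_logs:<user_id>"            # remaining retention, in seconds
```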


@@ -8,6 +8,7 @@ from pathlib import Path
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location
from cognee.shared.usage_logger import log_usage
import importlib.util
from contextlib import redirect_stdout
import mcp.types as types
@@ -91,6 +92,7 @@ async def health_check(request):
@mcp.tool()
@log_usage(function_name="MCP cognify", log_type="mcp_tool")
async def cognify(
data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
) -> list:
@@ -257,6 +259,7 @@ async def cognify(
@mcp.tool(
name="save_interaction", description="Logs user-agent interactions and query-answer pairs"
)
@log_usage(function_name="MCP save_interaction", log_type="mcp_tool")
async def save_interaction(data: str) -> list:
"""
Transform and save a user-agent interaction into structured knowledge.
@@ -316,6 +319,7 @@ async def save_interaction(data: str) -> list:
@mcp.tool()
@log_usage(function_name="MCP search", log_type="mcp_tool")
async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
"""
Search and query the knowledge graph for insights, information, and connections.
@@ -496,6 +500,7 @@ async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
@mcp.tool()
@log_usage(function_name="MCP list_data", log_type="mcp_tool")
async def list_data(dataset_id: str = None) -> list:
"""
List all datasets and their data items with IDs for deletion operations.
@@ -624,6 +629,7 @@ async def list_data(dataset_id: str = None) -> list:
@mcp.tool()
@log_usage(function_name="MCP delete", log_type="mcp_tool")
async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
"""
Delete specific data from a dataset in the Cognee knowledge graph.
@@ -703,6 +709,7 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
@mcp.tool()
@log_usage(function_name="MCP prune", log_type="mcp_tool")
async def prune():
"""
Reset the Cognee knowledge graph by removing all stored information.
@@ -739,6 +746,7 @@ async def prune():
@mcp.tool()
@log_usage(function_name="MCP cognify_status", log_type="mcp_tool")
async def cognify_status():
"""
Get the current status of the cognify pipeline.


@@ -10,6 +10,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
logger = get_logger()
@@ -19,6 +20,7 @@ def get_add_router() -> APIRouter:
router = APIRouter()
@router.post("", response_model=dict)
@log_usage(function_name="POST /v1/add", log_type="api_endpoint")
async def add(
data: List[UploadFile] = File(default=None),
datasetName: Optional[str] = Form(default=None),


@@ -29,6 +29,7 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
)
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
logger = get_logger("api.cognify")
@@ -57,6 +58,7 @@ def get_cognify_router() -> APIRouter:
router = APIRouter()
@router.post("", response_model=dict)
@log_usage(function_name="POST /v1/cognify", log_type="api_endpoint")
async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Transform datasets into structured knowledge graphs through cognitive processing.


@@ -12,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
logger = get_logger()
@@ -35,6 +36,7 @@ def get_memify_router() -> APIRouter:
router = APIRouter()
@router.post("", response_model=dict)
@log_usage(function_name="POST /v1/memify", log_type="api_endpoint")
async def memify(payload: MemifyPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Enrichment pipeline in Cognee, can work with already built graphs. If no data is provided existing knowledge graph will be used as data,


@@ -13,6 +13,7 @@ from cognee.modules.users.models import User
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.exceptions import CogneeValidationError
@@ -75,6 +76,7 @@ def get_search_router() -> APIRouter:
return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], List])
@log_usage(function_name="POST /v1/search", log_type="api_endpoint")
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Search for nodes in the graph database.


@@ -8,10 +8,13 @@ class CacheDBInterface(ABC):
Provides a common interface for lock acquisition, release, and context-managed locking.
"""
def __init__(self, host: str, port: int, lock_key: str):
def __init__(
self, host: str, port: int, lock_key: str = "default_lock", log_key: str = "usage_logs"
):
self.host = host
self.port = port
self.lock_key = lock_key
self.log_key = log_key
self.lock = None
@abstractmethod
@@ -77,3 +80,37 @@ class CacheDBInterface(ABC):
Gracefully close any async connections.
"""
pass
@abstractmethod
async def log_usage(
self,
user_id: str,
log_entry: dict,
ttl: int | None = 604800,
):
"""
Log usage information (API endpoint calls, MCP tool invocations) to cache.
Args:
user_id: The user ID.
log_entry: Dictionary containing usage log information.
ttl: Optional time-to-live (seconds). If provided, the log list expires after this time.
Raises:
CacheConnectionError: If cache connection fails or times out.
"""
pass
@abstractmethod
async def get_usage_logs(self, user_id: str, limit: int = 100):
"""
Retrieve usage logs for a given user.
Args:
user_id: The user ID.
limit: Maximum number of logs to retrieve (default: 100).
Returns:
List of usage log entries, most recent first.
"""
pass


@@ -13,6 +13,8 @@ class CacheConfig(BaseSettings):
- cache_port: Port number for the cache service.
- agentic_lock_expire: Automatic lock expiration time (in seconds).
- agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release.
- usage_logging: Enable/disable usage logging for API endpoints and MCP tools.
- usage_logging_ttl: Time-to-live for usage logs in seconds (default: 7 days).
"""
cache_backend: Literal["redis", "fs"] = "fs"
@@ -24,6 +26,8 @@ class CacheConfig(BaseSettings):
cache_password: Optional[str] = None
agentic_lock_expire: int = 240
agentic_lock_timeout: int = 300
usage_logging: bool = False
usage_logging_ttl: int = 604800
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@@ -38,6 +42,8 @@ class CacheConfig(BaseSettings):
"cache_password": self.cache_password,
"agentic_lock_expire": self.agentic_lock_expire,
"agentic_lock_timeout": self.agentic_lock_timeout,
"usage_logging": self.usage_logging,
"usage_logging_ttl": self.usage_logging_ttl,
}
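
A minimal `.env` sketch that enables the feature (values shown are
examples; field names map to the settings above):

```
USAGE_LOGGING=true
USAGE_LOGGING_TTL=604800
CACHE_BACKEND=redis
CACHE_HOST=localhost
CACHE_PORT=6379
```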


@@ -89,6 +89,27 @@ class FSCacheAdapter(CacheDBInterface):
return None
return json.loads(value)
async def log_usage(
self,
user_id: str,
log_entry: dict,
ttl: int | None = 604800,
):
"""
Usage logging is not supported in filesystem cache backend.
This method is a no-op to satisfy the interface.
"""
logger.warning("Usage logging not supported in FSCacheAdapter, skipping")
pass
async def get_usage_logs(self, user_id: str, limit: int = 100):
"""
Usage logging is not supported in filesystem cache backend.
This method returns an empty list to satisfy the interface.
"""
logger.warning("Usage logging not supported in FSCacheAdapter, returning empty list")
return []
async def close(self):
if self.cache is not None:
self.cache.expire()


@@ -1,7 +1,6 @@
"""Factory to get the appropriate cache coordination engine (e.g., Redis)."""
from functools import lru_cache
import os
from typing import Optional
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
@@ -17,6 +16,7 @@ def create_cache_engine(
cache_username: str,
cache_password: str,
lock_key: str,
log_key: str,
agentic_lock_expire: int = 240,
agentic_lock_timeout: int = 300,
):
@@ -30,6 +30,7 @@
- cache_username: Username to authenticate with.
- cache_password: Password to authenticate with.
- lock_key: Identifier used for the locking resource.
- log_key: Identifier used for usage logging.
- agentic_lock_expire: Duration to hold the lock after acquisition.
- agentic_lock_timeout: Max time to wait for the lock before failing.
@@ -37,7 +38,7 @@
--------
- CacheDBInterface: An instance of the appropriate cache adapter.
"""
if config.caching:
if config.caching or config.usage_logging:
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
if config.cache_backend == "redis":
@@ -47,6 +48,7 @@
username=cache_username,
password=cache_password,
lock_name=lock_key,
log_key=log_key,
timeout=agentic_lock_expire,
blocking_timeout=agentic_lock_timeout,
)
@@ -61,7 +63,10 @@
return None
def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface:
def get_cache_engine(
lock_key: Optional[str] = "default_lock",
log_key: Optional[str] = "usage_logs",
) -> Optional[CacheDBInterface]:
"""
Returns a cache adapter instance using current context configuration.
"""
@@ -72,6 +77,7 @@ def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface:
cache_username=config.cache_username,
cache_password=config.cache_password,
lock_key=lock_key,
log_key=log_key,
agentic_lock_expire=config.agentic_lock_expire,
agentic_lock_timeout=config.agentic_lock_timeout,
)
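
A short sketch of how the factory is typically called (the keyword
values are the defaults introduced in this PR):

```python
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine

# Returns a RedisAdapter when the redis backend is configured and caching
# or usage logging is enabled; otherwise returns None.
engine = get_cache_engine(lock_key="default_lock", log_key="usage_logs")
```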


@@ -17,13 +17,14 @@ class RedisAdapter(CacheDBInterface):
host,
port,
lock_name="default_lock",
log_key="usage_logs",
username=None,
password=None,
timeout=240,
blocking_timeout=300,
connection_timeout=30,
):
super().__init__(host, port, lock_name)
super().__init__(host, port, lock_name, log_key)
self.host = host
self.port = port
@@ -177,6 +178,64 @@ class RedisAdapter(CacheDBInterface):
entries = await self.async_redis.lrange(session_key, 0, -1)
return [json.loads(e) for e in entries]
async def log_usage(
self,
user_id: str,
log_entry: dict,
ttl: int | None = 604800,
):
"""
Log usage information (API endpoint calls, MCP tool invocations) to Redis.
Args:
user_id: The user ID.
log_entry: Dictionary containing usage log information.
ttl: Optional time-to-live (seconds). If provided, the log list expires after this time.
Raises:
CacheConnectionError: If Redis connection fails or times out.
"""
try:
usage_logs_key = f"{self.log_key}:{user_id}"
await self.async_redis.rpush(usage_logs_key, json.dumps(log_entry))
if ttl is not None:
await self.async_redis.expire(usage_logs_key, ttl)
except (redis.ConnectionError, redis.TimeoutError) as e:
error_msg = f"Redis connection error while logging usage: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
except Exception as e:
error_msg = f"Unexpected error while logging usage to Redis: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
async def get_usage_logs(self, user_id: str, limit: int = 100):
"""
Retrieve usage logs for a given user.
Args:
user_id: The user ID.
limit: Maximum number of logs to retrieve (default: 100).
Returns:
List of usage log entries, most recent first.
"""
try:
usage_logs_key = f"{self.log_key}:{user_id}"
entries = await self.async_redis.lrange(usage_logs_key, -limit, -1)
return [json.loads(e) for e in reversed(entries)] if entries else []
except (redis.ConnectionError, redis.TimeoutError) as e:
error_msg = f"Redis connection error while retrieving usage logs: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
except Exception as e:
error_msg = f"Unexpected error while retrieving usage logs from Redis: {str(e)}"
logger.error(error_msg)
raise CacheConnectionError(error_msg) from e
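# For illustration: with stored entries [e1, e2, e3] and limit=2, LRANGE
# returns [e2, e3] (the last two); reversing yields [e3, e2], so the most
# recent entry comes first.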
async def close(self):
"""
Gracefully close the async Redis connection.
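
For context, a minimal sketch of how the adapter methods fit together
(host, port, and the user ID are assumptions for illustration):

```python
import asyncio
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter

async def main():
    adapter = RedisAdapter(host="localhost", port=6379, log_key="usage_logs")
    # RPUSH the serialized entry onto usage_logs:demo-user, then set the
    # list's TTL so the whole log expires together.
    await adapter.log_usage(
        user_id="demo-user",
        log_entry={"function_name": "demo", "success": True},
        ttl=3600,
    )
    # Up to 10 entries, most recent first.
    print(await adapter.get_usage_logs("demo-user", limit=10))

asyncio.run(main())
```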


@@ -4,6 +4,4 @@
This module defines a set of exceptions for handling various shared utility errors
"""
from .exceptions import (
IngestionError,
)
from .exceptions import IngestionError, UsageLoggerError


@@ -1,4 +1,4 @@
from cognee.exceptions import CogneeValidationError
from cognee.exceptions import CogneeConfigurationError, CogneeValidationError
from fastapi import status
@@ -10,3 +10,13 @@ class IngestionError(CogneeValidationError):
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
):
super().__init__(message, name, status_code)
class UsageLoggerError(CogneeConfigurationError):
def __init__(
self,
message: str = "Usage logging configuration is invalid.",
name: str = "UsageLoggerError",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
):
super().__init__(message, name, status_code)


@@ -0,0 +1,332 @@
import asyncio
import inspect
import os
from datetime import datetime, timezone
from functools import singledispatch, wraps
from typing import Any, Callable, Optional
from uuid import UUID
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
from cognee.shared.exceptions import UsageLoggerError
from cognee.shared.logging_utils import get_logger
from cognee import __version__ as cognee_version
logger = get_logger("usage_logger")
@singledispatch
def _sanitize_value(value: Any) -> Any:
"""Default handler for JSON serialization - converts to string."""
try:
str_repr = str(value)
if str_repr.startswith("<") and str_repr.endswith(">"):
return f"<cannot be serialized: {type(value).__name__}>"
return str_repr
except Exception:
return f"<cannot be serialized: {type(value).__name__}>"
@_sanitize_value.register(type(None))
def _(value: None) -> None:
"""Handle None values - returns None as-is."""
return None
@_sanitize_value.register(str)
@_sanitize_value.register(int)
@_sanitize_value.register(float)
@_sanitize_value.register(bool)
def _(value: str | int | float | bool) -> str | int | float | bool:
"""Handle primitive types - returns value as-is since they're JSON-serializable."""
return value
@_sanitize_value.register(UUID)
def _(value: UUID) -> str:
"""Convert UUID to string representation."""
return str(value)
@_sanitize_value.register(datetime)
def _(value: datetime) -> str:
"""Convert datetime to ISO format string."""
return value.isoformat()
@_sanitize_value.register(list)
@_sanitize_value.register(tuple)
def _(value: list | tuple) -> list:
"""Recursively sanitize list or tuple elements."""
return [_sanitize_value(v) for v in value]
@_sanitize_value.register(dict)
def _(value: dict) -> dict:
"""Recursively sanitize dictionary keys and values."""
sanitized = {}
for k, v in value.items():
key_str = k if isinstance(k, str) else _sanitize_dict_key(k)
sanitized[key_str] = _sanitize_value(v)
return sanitized
def _sanitize_dict_key(key: Any) -> str:
"""Convert a non-string dict key to a string."""
sanitized_key = _sanitize_value(key)
if isinstance(sanitized_key, str):
if sanitized_key.startswith("<cannot be serialized"):
return f"<key:{type(key).__name__}>"
return sanitized_key
return str(sanitized_key)
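# For illustration: _sanitize_value({"id": UUID("123e4567-e89b-12d3-a456-426614174000"), "n": 3})
# returns {"id": "123e4567-e89b-12d3-a456-426614174000", "n": 3}; UUIDs become strings,
# primitives pass through unchanged, and non-string keys go through _sanitize_dict_key.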
def _get_param_names(func: Callable) -> list[str]:
"""Get parameter names from function signature."""
try:
return list(inspect.signature(func).parameters.keys())
except Exception:
return []
def _get_param_defaults(func: Callable) -> dict[str, Any]:
"""Get parameter defaults from function signature."""
try:
sig = inspect.signature(func)
defaults = {}
for param_name, param in sig.parameters.items():
if param.default != inspect.Parameter.empty:
defaults[param_name] = param.default
return defaults
except Exception:
return {}
def _extract_user_id(args: tuple, kwargs: dict, param_names: list[str]) -> Optional[str]:
"""Extract user_id from function arguments if available."""
try:
if "user" in kwargs and kwargs["user"] is not None:
user = kwargs["user"]
if hasattr(user, "id"):
return str(user.id)
for i, param_name in enumerate(param_names):
if i < len(args) and param_name == "user":
user = args[i]
if user is not None and hasattr(user, "id"):
return str(user.id)
return None
except Exception:
return None
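# For illustration: _extract_user_id((), {"user": some_user}, ["user"]) returns
# str(some_user.id); if no `user` argument exposes an .id attribute, it returns None.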
def _extract_parameters(args: tuple, kwargs: dict, param_names: list[str], func: Callable) -> dict:
"""Extract function parameters - captures all parameters including defaults, sanitizes for JSON."""
params = {}
for key, value in kwargs.items():
if key != "user":
params[key] = _sanitize_value(value)
if param_names:
for i, param_name in enumerate(param_names):
if i < len(args) and param_name != "user" and param_name not in kwargs:
params[param_name] = _sanitize_value(args[i])
else:
for i, arg_value in enumerate(args):
params[f"arg_{i}"] = _sanitize_value(arg_value)
if param_names:
defaults = _get_param_defaults(func)
for param_name in param_names:
if param_name != "user" and param_name not in params and param_name in defaults:
params[param_name] = _sanitize_value(defaults[param_name])
return params
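# For illustration: for async def f(a, b=2, user=None) called as f("x", user=u),
# the extracted parameters are {"a": "x", "b": 2}; the user object is excluded
# and the unset default for b is filled in from the signature.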
async def _log_usage_async(
function_name: str,
log_type: str,
user_id: Optional[str],
parameters: dict,
result: Any,
success: bool,
error: Optional[str],
duration_ms: float,
start_time: datetime,
end_time: datetime,
):
"""Asynchronously log function usage to Redis.
Args:
function_name: Name of the function being logged.
log_type: Type of log entry (e.g., "api_endpoint", "mcp_tool", "function").
user_id: User identifier, or None to use "unknown".
parameters: Dictionary of function parameters (sanitized).
result: Function return value (will be sanitized).
success: Whether the function executed successfully.
error: Error message if function failed, None otherwise.
duration_ms: Execution duration in milliseconds.
start_time: Function start timestamp.
end_time: Function end timestamp.
Note:
This function silently handles errors to avoid disrupting the original
function execution. Logs are written to Redis with TTL from config.
"""
try:
logger.debug(f"Starting to log usage for {function_name} at {start_time.isoformat()}")
config = get_cache_config()
if not config.usage_logging:
logger.debug("Usage logging disabled, skipping log")
return
logger.debug(f"Getting cache engine for {function_name}")
cache_engine = get_cache_engine()
if cache_engine is None:
logger.warning(
f"Cache engine not available for usage logging (function: {function_name})"
)
return
logger.debug(f"Cache engine obtained for {function_name}")
if user_id is None:
user_id = "unknown"
logger.debug(f"No user_id provided, using 'unknown' for {function_name}")
log_entry = {
"timestamp": start_time.isoformat(),
"type": log_type,
"function_name": function_name,
"user_id": user_id,
"parameters": parameters,
"result": _sanitize_value(result),
"success": success,
"error": error,
"duration_ms": round(duration_ms, 2),
"start_time": start_time.isoformat(),
"end_time": end_time.isoformat(),
"metadata": {
"cognee_version": cognee_version,
"environment": os.getenv("ENV", "prod"),
},
}
logger.debug(f"Calling log_usage for {function_name}, user_id={user_id}")
await cache_engine.log_usage(
user_id=user_id,
log_entry=log_entry,
ttl=config.usage_logging_ttl,
)
logger.info(f"Successfully logged usage for {function_name} (user_id={user_id})")
except Exception as e:
logger.error(f"Failed to log usage for {function_name}: {str(e)}", exc_info=True)
def log_usage(function_name: Optional[str] = None, log_type: str = "function"):
"""
Decorator to log function usage to Redis.
This decorator is completely transparent - it doesn't change function behavior.
It logs function name, parameters, result, timing, and user (if available).
Args:
function_name: Optional name for the function (defaults to func.__name__)
log_type: Type of log entry (e.g., "api_endpoint", "mcp_tool")
Usage:
@log_usage(function_name="MCP my_mcp_tool", log_type="mcp_tool")
async def my_mcp_tool(...):
# mcp code
@log_usage(function_name="POST API /v1/add", log_type="api_endpoint")
async def add(...):
# endpoint code
"""
def decorator(func: Callable) -> Callable:
"""Inner decorator that wraps the function with usage logging.
Args:
func: The async function to wrap with usage logging.
Returns:
Callable: The wrapped function with usage logging enabled.
Raises:
UsageLoggerError: If the function is not async.
"""
if not inspect.iscoroutinefunction(func):
raise UsageLoggerError(
f"@log_usage requires an async function. Got {func.__name__} which is not async."
)
@wraps(func)
async def async_wrapper(*args, **kwargs):
"""Wrapper function that executes the original function and logs usage.
This wrapper:
- Extracts user ID and parameters from function arguments
- Executes the original function
- Captures result, success status, and any errors
- Logs usage information asynchronously without blocking
Args:
*args: Positional arguments passed to the original function.
**kwargs: Keyword arguments passed to the original function.
Returns:
Any: The return value of the original function.
Raises:
Any exception raised by the original function (re-raised after logging).
"""
config = get_cache_config()
if not config.usage_logging:
return await func(*args, **kwargs)
start_time = datetime.now(timezone.utc)
param_names = _get_param_names(func)
user_id = _extract_user_id(args, kwargs, param_names)
parameters = _extract_parameters(args, kwargs, param_names, func)
result = None
success = True
error = None
try:
result = await func(*args, **kwargs)
return result
except Exception as e:
success = False
error = str(e)
raise
finally:
end_time = datetime.now(timezone.utc)
duration_ms = (end_time - start_time).total_seconds() * 1000
try:
await _log_usage_async(
function_name=function_name or func.__name__,
log_type=log_type,
user_id=user_id,
parameters=parameters,
result=result,
success=success,
error=error,
duration_ms=duration_ms,
start_time=start_time,
end_time=end_time,
)
except Exception as e:
logger.error(
f"Failed to log usage for {function_name or func.__name__}: {str(e)}",
exc_info=True,
)
return async_wrapper
return decorator
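
To tie the module together, a minimal runnable sketch of the decorator
(it assumes a reachable Redis with USAGE_LOGGING enabled; otherwise the
wrapper is a pass-through):

```python
import asyncio
from cognee.shared.usage_logger import log_usage

@log_usage(function_name="demo_tool", log_type="function")
async def demo_tool(query: str, top_k: int = 5, user=None) -> dict:
    return {"query": query, "top_k": top_k}

# With logging enabled, an entry keyed by the caller's user id (or
# "unknown") is pushed to Redis after the call completes; the return
# value itself is unchanged.
print(asyncio.run(demo_tool("Germany")))
```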


@@ -0,0 +1,255 @@
"""Integration tests for usage logger with real Redis components."""
import os
import pytest
import asyncio
from datetime import datetime, timezone
from types import SimpleNamespace
from uuid import UUID
from unittest.mock import patch
from cognee.shared.usage_logger import log_usage
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.get_cache_engine import (
get_cache_engine,
create_cache_engine,
)
@pytest.fixture
def usage_logging_config():
"""Fixture to enable usage logging via environment variables."""
original_env = os.environ.copy()
os.environ["USAGE_LOGGING"] = "true"
os.environ["CACHE_BACKEND"] = "redis"
os.environ["CACHE_HOST"] = "localhost"
os.environ["CACHE_PORT"] = "6379"
get_cache_config.cache_clear()
create_cache_engine.cache_clear()
yield
os.environ.clear()
os.environ.update(original_env)
get_cache_config.cache_clear()
create_cache_engine.cache_clear()
@pytest.fixture
def usage_logging_disabled():
"""Fixture to disable usage logging via environment variables."""
original_env = os.environ.copy()
os.environ["USAGE_LOGGING"] = "false"
os.environ["CACHE_BACKEND"] = "redis"
get_cache_config.cache_clear()
create_cache_engine.cache_clear()
yield
os.environ.clear()
os.environ.update(original_env)
get_cache_config.cache_clear()
create_cache_engine.cache_clear()
@pytest.fixture
def redis_adapter():
"""Real RedisAdapter instance for testing."""
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
try:
yield RedisAdapter(host="localhost", port=6379, log_key="test_usage_logs")
except Exception as e:
pytest.skip(f"Redis not available: {e}")
@pytest.fixture
def test_user():
"""Test user object."""
return SimpleNamespace(id="test-user-123")
class TestDecoratorBehavior:
"""Test decorator behavior with real components."""
@pytest.mark.asyncio
async def test_decorator_configuration(
self, usage_logging_disabled, usage_logging_config, redis_adapter
):
"""Test decorator skips when disabled and logs when enabled."""
# Test disabled
call_count = 0
@log_usage(function_name="test_func", log_type="test")
async def test_func():
nonlocal call_count
call_count += 1
return "result"
assert await test_func() == "result"
assert call_count == 1
# Test enabled with cache engine None
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
mock_get.return_value = None
assert await test_func() == "result"
@pytest.mark.asyncio
async def test_decorator_logging(self, usage_logging_config, redis_adapter, test_user):
"""Test decorator logs to Redis with correct structure."""
@log_usage(function_name="test_func", log_type="test")
async def test_func(param1: str, param2: int = 42, user=None):
await asyncio.sleep(0.01)
return {"result": f"{param1}_{param2}"}
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
mock_get.return_value = redis_adapter
result = await test_func("value1", user=test_user)
assert result == {"result": "value1_42"}
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
log = logs[0]
assert log["function_name"] == "test_func"
assert log["type"] == "test"
assert log["user_id"] == "test-user-123"
assert log["parameters"]["param1"] == "value1"
assert log["parameters"]["param2"] == 42
assert log["success"] is True
assert all(
field in log
for field in [
"timestamp",
"result",
"error",
"duration_ms",
"start_time",
"end_time",
"metadata",
]
)
assert "cognee_version" in log["metadata"]
@pytest.mark.asyncio
async def test_multiple_calls(self, usage_logging_config, redis_adapter, test_user):
"""Test multiple consecutive calls are all logged."""
@log_usage(function_name="multi_test", log_type="test")
async def multi_func(call_num: int, user=None):
return {"call": call_num}
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
mock_get.return_value = redis_adapter
for i in range(3):
await multi_func(i, user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert len(logs) >= 3
call_nums = {log["parameters"]["call_num"] for log in logs[:3]}
assert call_nums == {0, 1, 2}
class TestRealRedisIntegration:
"""Test real Redis integration."""
@pytest.mark.asyncio
async def test_redis_storage_retrieval_and_ttl(
self, usage_logging_config, redis_adapter, test_user
):
"""Test logs are stored, retrieved with correct order/limits, and TTL is set."""
@log_usage(function_name="redis_test", log_type="test")
async def redis_func(data: str, user=None):
return {"processed": data}
@log_usage(function_name="order_test", log_type="test")
async def order_func(num: int, user=None):
return {"num": num}
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
mock_get.return_value = redis_adapter
# Storage
await redis_func("test_data", user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert logs[0]["function_name"] == "redis_test"
assert logs[0]["parameters"]["data"] == "test_data"
# Order (most recent first)
for i in range(3):
await order_func(i, user=test_user)
await asyncio.sleep(0.01)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert [log["parameters"]["num"] for log in logs[:3]] == [2, 1, 0]
# Limit
assert len(await redis_adapter.get_usage_logs("test-user-123", limit=2)) == 2
# TTL
ttl = await redis_adapter.async_redis.ttl("test_usage_logs:test-user-123")
assert 0 < ttl <= 604800
class TestEdgeCases:
"""Test edge cases in integration tests."""
@pytest.mark.asyncio
async def test_edge_cases(self, usage_logging_config, redis_adapter, test_user):
"""Test no params, defaults, complex structures, exceptions, None, circular refs."""
@log_usage(function_name="no_params", log_type="test")
async def no_params_func(user=None):
return "result"
@log_usage(function_name="defaults_only", log_type="test")
async def defaults_only_func(param1: str = "default1", param2: int = 42, user=None):
return {"param1": param1, "param2": param2}
@log_usage(function_name="complex_test", log_type="test")
async def complex_func(user=None):
return {
"nested": {
"list": [1, 2, 3],
"uuid": UUID("123e4567-e89b-12d3-a456-426614174000"),
"datetime": datetime(2024, 1, 15, tzinfo=timezone.utc),
}
}
@log_usage(function_name="exception_test", log_type="test")
async def exception_func(user=None):
raise RuntimeError("Test exception")
@log_usage(function_name="none_test", log_type="test")
async def none_func(user=None):
return None
with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
mock_get.return_value = redis_adapter
# No parameters
await no_params_func(user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert logs[0]["parameters"] == {}
# Default parameters
await defaults_only_func(user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert logs[0]["parameters"]["param1"] == "default1"
assert logs[0]["parameters"]["param2"] == 42
# Complex nested structures
await complex_func(user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert isinstance(logs[0]["result"]["nested"]["uuid"], str)
assert isinstance(logs[0]["result"]["nested"]["datetime"], str)
# Exception handling
with pytest.raises(RuntimeError):
await exception_func(user=test_user)
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert logs[0]["success"] is False
assert "Test exception" in logs[0]["error"]
# None return value
assert await none_func(user=test_user) is None
logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
assert logs[0]["result"] is None


@@ -0,0 +1,268 @@
import os
import pytest
import pytest_asyncio
import asyncio
from fastapi.testclient import TestClient
import cognee
from cognee.api.client import app
from cognee.modules.users.methods import get_default_user, get_authenticated_user
async def _reset_engines_and_prune():
"""Reset db engine caches and prune data/system."""
try:
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
if hasattr(vector_engine, "engine") and hasattr(vector_engine.engine, "dispose"):
await vector_engine.engine.dispose(close=True)
except Exception:
pass
await cognee.prune.prune_data()
await cognee.prune.prune_system(metadata=True)
@pytest.fixture(scope="session")
def event_loop():
"""Use a single asyncio event loop for this test module."""
loop = asyncio.new_event_loop()
try:
yield loop
finally:
loop.close()
@pytest.fixture(scope="session")
def e2e_config():
"""Configure environment for E2E tests."""
original_env = os.environ.copy()
os.environ["USAGE_LOGGING"] = "true"
os.environ["CACHE_BACKEND"] = "redis"
os.environ["CACHE_HOST"] = "localhost"
os.environ["CACHE_PORT"] = "6379"
yield
os.environ.clear()
os.environ.update(original_env)
@pytest.fixture(scope="session")
def authenticated_client(test_client):
"""Override authentication to use default user."""
async def override_get_authenticated_user():
return await get_default_user()
app.dependency_overrides[get_authenticated_user] = override_get_authenticated_user
yield test_client
app.dependency_overrides.pop(get_authenticated_user, None)
@pytest_asyncio.fixture(scope="session")
async def test_data_setup():
"""Set up test data: prune first, then add file and cognify."""
await _reset_engines_and_prune()
dataset_name = "test_e2e_dataset"
test_text = "Germany is located in Europe right next to the Netherlands."
await cognee.add(test_text, dataset_name)
await cognee.cognify([dataset_name])
yield dataset_name
await _reset_engines_and_prune()
@pytest_asyncio.fixture
async def mcp_data_setup():
"""Set up test data for MCP tests: prune first, then add file and cognify."""
await _reset_engines_and_prune()
dataset_name = "test_mcp_dataset"
test_text = "Germany is located in Europe right next to the Netherlands."
await cognee.add(test_text, dataset_name)
await cognee.cognify([dataset_name])
yield dataset_name
await _reset_engines_and_prune()
@pytest.fixture(scope="session")
def test_client():
"""TestClient instance for API calls."""
with TestClient(app) as client:
yield client
@pytest_asyncio.fixture
async def cache_engine(e2e_config):
"""Get cache engine for log verification in test's event loop."""
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
from cognee.infrastructure.databases.cache.config import get_cache_config
config = get_cache_config()
if not config.usage_logging or config.cache_backend != "redis":
pytest.skip("Redis usage logging not configured")
engine = RedisAdapter(
host=config.cache_host,
port=config.cache_port,
username=config.cache_username,
password=config.cache_password,
log_key="usage_logs",
)
return engine
@pytest.mark.asyncio
async def test_api_endpoint_logging(e2e_config, authenticated_client, cache_engine):
"""Test that API endpoints succeed and log to Redis."""
user = await get_default_user()
dataset_name = "test_e2e_api_dataset"
add_response = authenticated_client.post(
"/api/v1/add",
data={"datasetName": dataset_name},
files=[
(
"data",
(
"test.txt",
b"Germany is located in Europe right next to the Netherlands.",
"text/plain",
),
)
],
)
assert add_response.status_code in [200, 201], f"Add endpoint failed: {add_response.text}"
cognify_response = authenticated_client.post(
"/api/v1/cognify",
json={"datasets": [dataset_name], "run_in_background": False},
)
assert cognify_response.status_code in [200, 201], (
f"Cognify endpoint failed: {cognify_response.text}"
)
search_response = authenticated_client.post(
"/api/v1/search",
json={"query": "Germany", "search_type": "GRAPH_COMPLETION", "datasets": [dataset_name]},
)
assert search_response.status_code == 200, f"Search endpoint failed: {search_response.text}"
logs = await cache_engine.get_usage_logs(str(user.id), limit=20)
add_logs = [log for log in logs if log.get("function_name") == "POST /v1/add"]
assert len(add_logs) > 0
assert add_logs[0]["type"] == "api_endpoint"
assert add_logs[0]["user_id"] == str(user.id)
assert add_logs[0]["success"] is True
cognify_logs = [log for log in logs if log.get("function_name") == "POST /v1/cognify"]
assert len(cognify_logs) > 0
assert cognify_logs[0]["type"] == "api_endpoint"
assert cognify_logs[0]["user_id"] == str(user.id)
assert cognify_logs[0]["success"] is True
search_logs = [log for log in logs if log.get("function_name") == "POST /v1/search"]
assert len(search_logs) > 0
assert search_logs[0]["type"] == "api_endpoint"
assert search_logs[0]["user_id"] == str(user.id)
assert search_logs[0]["success"] is True
@pytest.mark.asyncio
async def test_mcp_tool_logging(e2e_config, cache_engine):
"""Test that MCP tools succeed and log to Redis."""
import sys
import importlib.util
from pathlib import Path
await _reset_engines_and_prune()
repo_root = Path(__file__).parent.parent.parent
mcp_src_path = repo_root / "cognee-mcp" / "src"
mcp_server_path = mcp_src_path / "server.py"
if not mcp_server_path.exists():
pytest.skip(f"MCP server not found at {mcp_server_path}")
if str(mcp_src_path) not in sys.path:
sys.path.insert(0, str(mcp_src_path))
spec = importlib.util.spec_from_file_location("mcp_server_module", mcp_server_path)
mcp_server_module = importlib.util.module_from_spec(spec)
import os
original_cwd = os.getcwd()
try:
os.chdir(str(mcp_src_path))
spec.loader.exec_module(mcp_server_module)
finally:
os.chdir(original_cwd)
if mcp_server_module.cognee_client is None:
cognee_client_path = mcp_src_path / "cognee_client.py"
if cognee_client_path.exists():
spec_client = importlib.util.spec_from_file_location(
"cognee_client", cognee_client_path
)
cognee_client_module = importlib.util.module_from_spec(spec_client)
spec_client.loader.exec_module(cognee_client_module)
CogneeClient = cognee_client_module.CogneeClient
mcp_server_module.cognee_client = CogneeClient()
else:
pytest.skip(f"CogneeClient not found at {cognee_client_path}")
test_text = "Germany is located in Europe right next to the Netherlands."
await mcp_server_module.cognify(data=test_text)
await asyncio.sleep(30.0)
list_result = await mcp_server_module.list_data()
assert list_result is not None, "List data should return results"
search_result = await mcp_server_module.search(
search_query="Germany", search_type="GRAPH_COMPLETION", top_k=5
)
assert search_result is not None, "Search should return results"
interaction_data = "User: What is Germany?\nAgent: Germany is a country in Europe."
await mcp_server_module.save_interaction(data=interaction_data)
await asyncio.sleep(30.0)
status_result = await mcp_server_module.cognify_status()
assert status_result is not None, "Cognify status should return results"
await mcp_server_module.prune()
await asyncio.sleep(0.5)
logs = await cache_engine.get_usage_logs("unknown", limit=50)
mcp_logs = [log for log in logs if log.get("type") == "mcp_tool"]
assert len(mcp_logs) > 0, (
f"Should have MCP tool logs with user_id='unknown'. Found logs: {[log.get('function_name') for log in logs[:5]]}"
)
assert len(mcp_logs) == 6
function_names = [log.get("function_name") for log in mcp_logs]
expected_tools = [
"MCP cognify",
"MCP list_data",
"MCP search",
"MCP save_interaction",
"MCP cognify_status",
"MCP prune",
]
for expected_tool in expected_tools:
assert expected_tool in function_names, (
f"Should have {expected_tool} log. Found: {function_names}"
)
for log in mcp_logs:
assert log["type"] == "mcp_tool"
assert log["user_id"] == "unknown"
assert log["success"] is True


@@ -62,6 +62,8 @@ def test_cache_config_to_dict():
"cache_password": None,
"agentic_lock_expire": 100,
"agentic_lock_timeout": 200,
"usage_logging": False,
"usage_logging_ttl": 604800,
}


@@ -0,0 +1,241 @@
"""Unit tests for usage logger core functions."""
import pytest
from datetime import datetime, timezone
from uuid import UUID
from types import SimpleNamespace
from cognee.shared.usage_logger import (
_sanitize_value,
_sanitize_dict_key,
_get_param_names,
_get_param_defaults,
_extract_user_id,
_extract_parameters,
log_usage,
)
from cognee.shared.exceptions import UsageLoggerError
class TestSanitizeValue:
"""Test _sanitize_value function."""
@pytest.mark.parametrize(
"value,expected",
[
(None, None),
("string", "string"),
(42, 42),
(3.14, 3.14),
(True, True),
(False, False),
],
)
def test_basic_types(self, value, expected):
assert _sanitize_value(value) == expected
def test_uuid_and_datetime(self):
"""Test UUID and datetime serialization."""
uuid_val = UUID("123e4567-e89b-12d3-a456-426614174000")
dt = datetime(2024, 1, 15, 12, 30, 45, tzinfo=timezone.utc)
assert _sanitize_value(uuid_val) == "123e4567-e89b-12d3-a456-426614174000"
assert _sanitize_value(dt) == "2024-01-15T12:30:45+00:00"
def test_collections(self):
"""Test list, tuple, and dict serialization."""
assert _sanitize_value(
[1, "string", UUID("123e4567-e89b-12d3-a456-426614174000"), None]
) == [1, "string", "123e4567-e89b-12d3-a456-426614174000", None]
assert _sanitize_value((1, "string", True)) == [1, "string", True]
assert _sanitize_value({"key": UUID("123e4567-e89b-12d3-a456-426614174000")}) == {
"key": "123e4567-e89b-12d3-a456-426614174000"
}
assert _sanitize_value([]) == []
assert _sanitize_value({}) == {}
def test_nested_and_complex(self):
"""Test nested structures and non-serializable types."""
# Nested structure
nested = {"level1": {"level2": {"level3": [1, 2, {"nested": "value"}]}}}
assert _sanitize_value(nested)["level1"]["level2"]["level3"][2]["nested"] == "value"
# Non-serializable
class CustomObject:
def __str__(self):
return "<CustomObject instance>"
result = _sanitize_value(CustomObject())
assert isinstance(result, str)
assert "<cannot be serialized" in result or "<CustomObject" in result
class TestSanitizeDictKey:
"""Test _sanitize_dict_key function."""
@pytest.mark.parametrize(
"key,expected_contains",
[
("simple_key", "simple_key"),
(UUID("123e4567-e89b-12d3-a456-426614174000"), "123e4567-e89b-12d3-a456-426614174000"),
((1, 2, 3), ["1", "2"]),
],
)
def test_key_types(self, key, expected_contains):
result = _sanitize_dict_key(key)
assert isinstance(result, str)
if isinstance(expected_contains, list):
assert all(item in result for item in expected_contains)
else:
assert expected_contains in result
def test_non_serializable_key(self):
class BadKey:
def __str__(self):
return "<BadKey instance>"
result = _sanitize_dict_key(BadKey())
assert isinstance(result, str)
assert "<key:" in result or "<BadKey" in result
class TestGetParamNames:
"""Test _get_param_names function."""
@pytest.mark.parametrize(
"func_def,expected",
[
(lambda a, b, c: None, ["a", "b", "c"]),
(lambda a, b=42, c="default": None, ["a", "b", "c"]),
(lambda a, **kwargs: None, ["a", "kwargs"]),
(lambda *args: None, ["args"]),
],
)
def test_param_extraction(self, func_def, expected):
assert _get_param_names(func_def) == expected
def test_async_function(self):
async def func(a, b):
pass
assert _get_param_names(func) == ["a", "b"]
class TestGetParamDefaults:
"""Test _get_param_defaults function."""
@pytest.mark.parametrize(
"func_def,expected",
[
(lambda a, b=42, c="default", d=None: None, {"b": 42, "c": "default", "d": None}),
(lambda a, b, c: None, {}),
(lambda a, b=10, c="test", d=None: None, {"b": 10, "c": "test", "d": None}),
],
)
def test_default_extraction(self, func_def, expected):
assert _get_param_defaults(func_def) == expected
class TestExtractUserId:
"""Test _extract_user_id function."""
def test_user_extraction(self):
"""Test extracting user_id from kwargs and args."""
user1 = SimpleNamespace(id=UUID("123e4567-e89b-12d3-a456-426614174000"))
user2 = SimpleNamespace(id="user-123")
# From kwargs
assert (
_extract_user_id((), {"user": user1}, ["user", "other"])
== "123e4567-e89b-12d3-a456-426614174000"
)
# From args
assert _extract_user_id((user2, "other"), {}, ["user", "other"]) == "user-123"
# Not present
assert _extract_user_id(("arg1",), {}, ["param1"]) is None
# None value
assert _extract_user_id((None,), {}, ["user"]) is None
# No id attribute
assert _extract_user_id((SimpleNamespace(name="test"),), {}, ["user"]) is None
class TestExtractParameters:
"""Test _extract_parameters function."""
def test_parameter_extraction(self):
"""Test parameter extraction with various scenarios."""
def func1(param1, param2, user=None):
pass
def func2(param1, param2=42, param3="default", user=None):
pass
def func3():
pass
def func4(param1, user):
pass
# Kwargs only
result = _extract_parameters(
(), {"param1": "v1", "param2": 42}, _get_param_names(func1), func1
)
assert result == {"param1": "v1", "param2": 42}
assert "user" not in result
# Args only
result = _extract_parameters(("v1", 42), {}, _get_param_names(func1), func1)
assert result == {"param1": "v1", "param2": 42}
# Mixed args/kwargs
result = _extract_parameters(("v1",), {"param3": "v3"}, _get_param_names(func2), func2)
assert result["param1"] == "v1" and result["param3"] == "v3"
# Defaults included
result = _extract_parameters(("v1",), {}, _get_param_names(func2), func2)
assert result["param1"] == "v1" and result["param2"] == 42 and result["param3"] == "default"
# No parameters
assert _extract_parameters((), {}, _get_param_names(func3), func3) == {}
# User excluded
user = SimpleNamespace(id="user-123")
result = _extract_parameters(("v1", user), {}, _get_param_names(func4), func4)
assert result == {"param1": "v1"} and "user" not in result
# Fallback when inspection fails
class BadFunc:
pass
result = _extract_parameters(("arg1", "arg2"), {}, [], BadFunc())
assert "arg_0" in result or "arg_1" in result
class TestDecoratorValidation:
"""Test decorator validation and behavior."""
def test_decorator_validation(self):
"""Test decorator validation and metadata preservation."""
# Sync function raises error
with pytest.raises(UsageLoggerError, match="requires an async function"):
@log_usage()
def sync_func():
pass
# Async function accepted
@log_usage()
async def async_func():
pass
assert callable(async_func)
# Metadata preserved
@log_usage(function_name="test_func", log_type="test")
async def test_func(param1: str, param2: int = 42):
"""Test docstring."""
return param1
assert test_func.__name__ == "test_func"
assert "Test docstring" in test_func.__doc__