Merge branch 'dev' into fix-mcp-migrations

This commit is contained in:
Igor Ilic 2026-01-20 16:18:44 +01:00 committed by GitHub
commit c29fb51619
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
51 changed files with 1529 additions and 82 deletions

View file

@ -659,3 +659,51 @@ jobs:
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
run: uv run python ./cognee/tests/test_pipeline_cache.py
run_usage_logger_test:
name: Usage logger test (API/MCP)
runs-on: ubuntu-latest
defaults:
run:
shell: bash
services:
redis:
image: redis:7
ports:
- 6379:6379
options: >-
--health-cmd "redis-cli ping"
--health-interval 5s
--health-timeout 3s
--health-retries 5
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Cognee Setup
uses: ./.github/actions/cognee_setup
with:
python-version: '3.11.x'
extra-dependencies: "redis"
- name: Install cognee-mcp (local version)
shell: bash
run: |
uv pip install -e ./cognee-mcp
- name: Run api/tool usage logger
env:
ENV: dev
LLM_MODEL: ${{ secrets.LLM_MODEL }}
LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
GRAPH_DATABASE_PROVIDER: 'kuzu'
USAGE_LOGGING: true
CACHE_BACKEND: 'redis'
run: uv run pytest cognee/tests/test_usage_logger_e2e.py -v --log-level=INFO

View file

@ -8,6 +8,7 @@ from pathlib import Path
from typing import Optional
from cognee.shared.logging_utils import get_logger, setup_logging, get_log_file_location
from cognee.shared.usage_logger import log_usage
import importlib.util
from contextlib import redirect_stdout
import mcp.types as types
@ -91,6 +92,7 @@ async def health_check(request):
@mcp.tool()
@log_usage(function_name="MCP cognify", log_type="mcp_tool")
async def cognify(
data: str, graph_model_file: str = None, graph_model_name: str = None, custom_prompt: str = None
) -> list:
@ -257,6 +259,7 @@ async def cognify(
@mcp.tool(
name="save_interaction", description="Logs user-agent interactions and query-answer pairs"
)
@log_usage(function_name="MCP save_interaction", log_type="mcp_tool")
async def save_interaction(data: str) -> list:
"""
Transform and save a user-agent interaction into structured knowledge.
@ -316,6 +319,7 @@ async def save_interaction(data: str) -> list:
@mcp.tool()
@log_usage(function_name="MCP search", log_type="mcp_tool")
async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
"""
Search and query the knowledge graph for insights, information, and connections.
@ -496,6 +500,7 @@ async def search(search_query: str, search_type: str, top_k: int = 10) -> list:
@mcp.tool()
@log_usage(function_name="MCP list_data", log_type="mcp_tool")
async def list_data(dataset_id: str = None) -> list:
"""
List all datasets and their data items with IDs for deletion operations.
@ -624,6 +629,7 @@ async def list_data(dataset_id: str = None) -> list:
@mcp.tool()
@log_usage(function_name="MCP delete", log_type="mcp_tool")
async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
"""
Delete specific data from a dataset in the Cognee knowledge graph.
@ -703,6 +709,7 @@ async def delete(data_id: str, dataset_id: str, mode: str = "soft") -> list:
@mcp.tool()
@log_usage(function_name="MCP prune", log_type="mcp_tool")
async def prune():
"""
Reset the Cognee knowledge graph by removing all stored information.
@ -739,6 +746,7 @@ async def prune():
@mcp.tool()
@log_usage(function_name="MCP cognify_status", log_type="mcp_tool")
async def cognify_status():
"""
Get the current status of the cognify pipeline.

View file

@ -10,6 +10,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
logger = get_logger()
@ -19,6 +20,7 @@ def get_add_router() -> APIRouter:
router = APIRouter()
@router.post("", response_model=dict)
@log_usage(function_name="POST /v1/add", log_type="api_endpoint")
async def add(
data: List[UploadFile] = File(default=None),
datasetName: Optional[str] = Form(default=None),

View file

@ -29,6 +29,7 @@ from cognee.modules.pipelines.queues.pipeline_run_info_queues import (
)
from cognee.shared.logging_utils import get_logger
from cognee.shared.utils import send_telemetry
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
logger = get_logger("api.cognify")
@ -57,6 +58,7 @@ def get_cognify_router() -> APIRouter:
router = APIRouter()
@router.post("", response_model=dict)
@log_usage(function_name="POST /v1/cognify", log_type="api_endpoint")
async def cognify(payload: CognifyPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Transform datasets into structured knowledge graphs through cognitive processing.

View file

@ -12,6 +12,7 @@ from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.modules.pipelines.models import PipelineRunErrored
from cognee.shared.logging_utils import get_logger
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
logger = get_logger()
@ -35,6 +36,7 @@ def get_memify_router() -> APIRouter:
router = APIRouter()
@router.post("", response_model=dict)
@log_usage(function_name="POST /v1/memify", log_type="api_endpoint")
async def memify(payload: MemifyPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Enrichment pipeline in Cognee, can work with already built graphs. If no data is provided existing knowledge graph will be used as data,

View file

@ -13,6 +13,7 @@ from cognee.modules.users.models import User
from cognee.modules.search.operations import get_history
from cognee.modules.users.methods import get_authenticated_user
from cognee.shared.utils import send_telemetry
from cognee.shared.usage_logger import log_usage
from cognee import __version__ as cognee_version
from cognee.infrastructure.databases.exceptions import DatabaseNotCreatedError
from cognee.exceptions import CogneeValidationError
@ -75,6 +76,7 @@ def get_search_router() -> APIRouter:
return JSONResponse(status_code=500, content={"error": str(error)})
@router.post("", response_model=Union[List[SearchResult], List])
@log_usage(function_name="POST /v1/search", log_type="api_endpoint")
async def search(payload: SearchPayloadDTO, user: User = Depends(get_authenticated_user)):
"""
Search for nodes in the graph database.

View file

@ -8,10 +8,13 @@ class CacheDBInterface(ABC):
Provides a common interface for lock acquisition, release, and context-managed locking.
"""
def __init__(self, host: str, port: int, lock_key: str):
def __init__(
self, host: str, port: int, lock_key: str = "default_lock", log_key: str = "usage_logs"
):
self.host = host
self.port = port
self.lock_key = lock_key
self.log_key = log_key
self.lock = None
@abstractmethod
@ -77,3 +80,37 @@ class CacheDBInterface(ABC):
Gracefully close any async connections.
"""
pass
@abstractmethod
async def log_usage(
    self,
    user_id: str,
    log_entry: dict,
    ttl: int | None = 604800,
):
    """
    Log usage information (API endpoint calls, MCP tool invocations) to cache.

    Args:
        user_id: The user ID the log entry is attributed to.
        log_entry: Dictionary containing usage log information.
        ttl: Optional time-to-live in seconds (default 604800 = 7 days).
            If provided, the log list expires after this time; None
            disables expiry.

    Raises:
        CacheConnectionError: If cache connection fails or times out.
    """
    pass
@abstractmethod
async def get_usage_logs(self, user_id: str, limit: int = 100):
    """
    Retrieve usage logs for a given user.

    Args:
        user_id: The user ID whose logs are requested.
        limit: Maximum number of logs to retrieve (default: 100).

    Returns:
        List of usage log entries, most recent first.
    """
    pass

View file

@ -13,6 +13,8 @@ class CacheConfig(BaseSettings):
- cache_port: Port number for the cache service.
- agentic_lock_expire: Automatic lock expiration time (in seconds).
- agentic_lock_timeout: Maximum time (in seconds) to wait for the lock release.
- usage_logging: Enable/disable usage logging for API endpoints and MCP tools.
- usage_logging_ttl: Time-to-live for usage logs in seconds (default: 7 days).
"""
cache_backend: Literal["redis", "fs"] = "fs"
@ -24,6 +26,8 @@ class CacheConfig(BaseSettings):
cache_password: Optional[str] = None
agentic_lock_expire: int = 240
agentic_lock_timeout: int = 300
usage_logging: bool = False
usage_logging_ttl: int = 604800
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -38,6 +42,8 @@ class CacheConfig(BaseSettings):
"cache_password": self.cache_password,
"agentic_lock_expire": self.agentic_lock_expire,
"agentic_lock_timeout": self.agentic_lock_timeout,
"usage_logging": self.usage_logging,
"usage_logging_ttl": self.usage_logging_ttl,
}

View file

@ -89,6 +89,27 @@ class FSCacheAdapter(CacheDBInterface):
return None
return json.loads(value)
async def log_usage(
    self,
    user_id: str,
    log_entry: dict,
    ttl: int | None = 604800,
):
    """
    No-op stub: the filesystem cache backend cannot persist usage logs.

    Emits a warning so a misconfigured deployment is visible in the logs,
    then returns None to satisfy the CacheDBInterface contract.
    """
    logger.warning("Usage logging not supported in FSCacheAdapter, skipping")
async def get_usage_logs(self, user_id: str, limit: int = 100):
    """
    Always return an empty list: the filesystem backend stores no usage logs.

    A warning is emitted so callers can see why no logs are available; the
    empty list satisfies the CacheDBInterface contract.
    """
    logger.warning("Usage logging not supported in FSCacheAdapter, returning empty list")
    return []
async def close(self):
if self.cache is not None:
self.cache.expire()

View file

@ -1,7 +1,6 @@
"""Factory to get the appropriate cache coordination engine (e.g., Redis)."""
from functools import lru_cache
import os
from typing import Optional
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.cache_db_interface import CacheDBInterface
@ -17,6 +16,7 @@ def create_cache_engine(
cache_username: str,
cache_password: str,
lock_key: str,
log_key: str,
agentic_lock_expire: int = 240,
agentic_lock_timeout: int = 300,
):
@ -30,6 +30,7 @@ def create_cache_engine(
- cache_username: Username to authenticate with.
- cache_password: Password to authenticate with.
- lock_key: Identifier used for the locking resource.
- log_key: Identifier used for usage logging.
- agentic_lock_expire: Duration to hold the lock after acquisition.
- agentic_lock_timeout: Max time to wait for the lock before failing.
@ -37,7 +38,7 @@ def create_cache_engine(
--------
- CacheDBInterface: An instance of the appropriate cache adapter.
"""
if config.caching:
if config.caching or config.usage_logging:
from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
if config.cache_backend == "redis":
@ -47,6 +48,7 @@ def create_cache_engine(
username=cache_username,
password=cache_password,
lock_name=lock_key,
log_key=log_key,
timeout=agentic_lock_expire,
blocking_timeout=agentic_lock_timeout,
)
@ -61,7 +63,10 @@ def create_cache_engine(
return None
def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface:
def get_cache_engine(
lock_key: Optional[str] = "default_lock",
log_key: Optional[str] = "usage_logs",
) -> Optional[CacheDBInterface]:
"""
Returns a cache adapter instance using current context configuration.
"""
@ -72,6 +77,7 @@ def get_cache_engine(lock_key: Optional[str] = None) -> CacheDBInterface:
cache_username=config.cache_username,
cache_password=config.cache_password,
lock_key=lock_key,
log_key=log_key,
agentic_lock_expire=config.agentic_lock_expire,
agentic_lock_timeout=config.agentic_lock_timeout,
)

View file

@ -17,13 +17,14 @@ class RedisAdapter(CacheDBInterface):
host,
port,
lock_name="default_lock",
log_key="usage_logs",
username=None,
password=None,
timeout=240,
blocking_timeout=300,
connection_timeout=30,
):
super().__init__(host, port, lock_name)
super().__init__(host, port, lock_name, log_key)
self.host = host
self.port = port
@ -177,6 +178,64 @@ class RedisAdapter(CacheDBInterface):
entries = await self.async_redis.lrange(session_key, 0, -1)
return [json.loads(e) for e in entries]
async def log_usage(
    self,
    user_id: str,
    log_entry: dict,
    ttl: int | None = 604800,
):
    """
    Append a usage record (API call / MCP tool invocation) to the user's Redis list.

    Args:
        user_id: The user ID; combined with `self.log_key` to form the list key.
        log_entry: JSON-serializable dict describing the usage event.
        ttl: Expiry in seconds re-applied to the whole list after each append
            (default 604800 = 7 days); pass None to keep the list indefinitely.

    Raises:
        CacheConnectionError: On Redis connection/timeout failures, or any
            other unexpected error while logging.
    """
    try:
        key = f"{self.log_key}:{user_id}"
        await self.async_redis.rpush(key, json.dumps(log_entry))
        if ttl is not None:
            await self.async_redis.expire(key, ttl)
    except (redis.ConnectionError, redis.TimeoutError) as exc:
        message = f"Redis connection error while logging usage: {str(exc)}"
        logger.error(message)
        raise CacheConnectionError(message) from exc
    except Exception as exc:
        message = f"Unexpected error while logging usage to Redis: {str(exc)}"
        logger.error(message)
        raise CacheConnectionError(message) from exc
async def get_usage_logs(self, user_id: str, limit: int = 100):
    """
    Fetch up to *limit* usage records for a user, newest first.

    Args:
        user_id: The user ID whose log list is read.
        limit: Maximum number of entries to return (default: 100).

    Returns:
        List of decoded log entries, most recent first ([] when none exist).

    Raises:
        CacheConnectionError: On Redis connection/timeout failures, or any
            other unexpected error during retrieval.
    """
    try:
        key = f"{self.log_key}:{user_id}"
        raw_entries = await self.async_redis.lrange(key, -limit, -1)
        if not raw_entries:
            return []
        # lrange returns oldest -> newest; reverse so callers get newest first.
        return [json.loads(entry) for entry in reversed(raw_entries)]
    except (redis.ConnectionError, redis.TimeoutError) as exc:
        message = f"Redis connection error while retrieving usage logs: {str(exc)}"
        logger.error(message)
        raise CacheConnectionError(message) from exc
    except Exception as exc:
        message = f"Unexpected error while retrieving usage logs from Redis: {str(exc)}"
        logger.error(message)
        raise CacheConnectionError(message) from exc
async def close(self):
"""
Gracefully close the async Redis connection.

View file

@ -24,7 +24,6 @@ async def get_graph_engine() -> GraphDBInterface:
return graph_client
@lru_cache
def create_graph_engine(
graph_database_provider,
graph_file_path,
@ -35,6 +34,35 @@ def create_graph_engine(
graph_database_port="",
graph_database_key="",
graph_dataset_database_handler="",
):
"""
Wrapper function to call create graph engine with caching.
For a detailed description, see _create_graph_engine.
"""
return _create_graph_engine(
graph_database_provider,
graph_file_path,
graph_database_url,
graph_database_name,
graph_database_username,
graph_database_password,
graph_database_port,
graph_database_key,
graph_dataset_database_handler,
)
@lru_cache
def _create_graph_engine(
graph_database_provider,
graph_file_path,
graph_database_url="",
graph_database_name="",
graph_database_username="",
graph_database_password="",
graph_database_port="",
graph_database_key="",
graph_dataset_database_handler="",
):
"""
Create a graph engine based on the specified provider type.

View file

@ -236,6 +236,7 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
query_vector: Optional[List[float]] = None,
limit: Optional[int] = None,
with_vector: bool = False,
include_payload: bool = False, # TODO: Add support for this parameter
):
"""
Perform a search in the specified collection using either a text query or a vector
@ -319,7 +320,12 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
self._na_exception_handler(e, query_string)
async def batch_search(
self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False
self,
collection_name: str,
query_texts: List[str],
limit: int,
with_vectors: bool = False,
include_payload: bool = False,
):
"""
Perform a batch search using multiple text queries against a collection.
@ -342,7 +348,14 @@ class NeptuneAnalyticsAdapter(NeptuneGraphDB, VectorDBInterface):
data_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[
self.search(collection_name, None, vector, limit, with_vectors)
self.search(
collection_name,
None,
vector,
limit,
with_vectors,
include_payload=include_payload,
)
for vector in data_vectors
]
)

View file

@ -355,6 +355,7 @@ class ChromaDBAdapter(VectorDBInterface):
limit: Optional[int] = 15,
with_vector: bool = False,
normalized: bool = True,
include_payload: bool = False, # TODO: Add support for this parameter when set to False
):
"""
Search for items in a collection using either a text or a vector query.
@ -441,6 +442,7 @@ class ChromaDBAdapter(VectorDBInterface):
query_texts: List[str],
limit: int = 5,
with_vectors: bool = False,
include_payload: bool = False,
):
"""
Perform multiple searches in a single request for efficiency, returning results for each

View file

@ -7,7 +7,6 @@ from cognee.infrastructure.databases.graph.config import get_graph_context_confi
from functools import lru_cache
@lru_cache
def create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
@ -15,6 +14,29 @@ def create_vector_engine(
vector_db_port: str = "",
vector_db_key: str = "",
vector_dataset_database_handler: str = "",
):
"""
Wrapper function to call create vector engine with caching.
For a detailed description, see _create_vector_engine.
"""
return _create_vector_engine(
vector_db_provider,
vector_db_url,
vector_db_name,
vector_db_port,
vector_db_key,
vector_dataset_database_handler,
)
@lru_cache
def _create_vector_engine(
vector_db_provider: str,
vector_db_url: str,
vector_db_name: str,
vector_db_port: str = "",
vector_db_key: str = "",
vector_dataset_database_handler: str = "",
):
"""
Create a vector database engine based on the specified provider.

View file

@ -231,6 +231,7 @@ class LanceDBAdapter(VectorDBInterface):
limit: Optional[int] = 15,
with_vector: bool = False,
normalized: bool = True,
include_payload: bool = False,
):
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
@ -247,17 +248,27 @@ class LanceDBAdapter(VectorDBInterface):
if limit <= 0:
return []
result_values = await collection.vector_search(query_vector).limit(limit).to_list()
# Note: Exclude payload if not needed to optimize performance
select_columns = (
["id", "vector", "payload", "_distance"]
if include_payload
else ["id", "vector", "_distance"]
)
result_values = (
await collection.vector_search(query_vector)
.select(select_columns)
.limit(limit)
.to_list()
)
if not result_values:
return []
normalized_values = normalize_distances(result_values)
return [
ScoredResult(
id=parse_id(result["id"]),
payload=result["payload"],
payload=result["payload"] if include_payload else None,
score=normalized_values[value_index],
)
for value_index, result in enumerate(result_values)
@ -269,6 +280,7 @@ class LanceDBAdapter(VectorDBInterface):
query_texts: List[str],
limit: Optional[int] = None,
with_vectors: bool = False,
include_payload: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
@ -279,6 +291,7 @@ class LanceDBAdapter(VectorDBInterface):
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
include_payload=include_payload,
)
for query_vector in query_vectors
]

View file

@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional
from uuid import UUID
from pydantic import BaseModel
@ -12,10 +12,10 @@ class ScoredResult(BaseModel):
- id (UUID): Unique identifier for the scored result.
- score (float): The score associated with the result, where a lower score indicates a
better outcome.
- payload (Dict[str, Any]): Additional information related to the score, stored as
- payload (Optional[Dict[str, Any]]): Additional information related to the score, stored as
key-value pairs in a dictionary.
"""
id: UUID
score: float # Lower score is better
payload: Dict[str, Any]
payload: Optional[Dict[str, Any]] = None

View file

@ -301,6 +301,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query_vector: Optional[List[float]] = None,
limit: Optional[int] = 15,
with_vector: bool = False,
include_payload: bool = False,
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
@ -324,10 +325,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# NOTE: This needs to be initialized in case search doesn't return a value
closest_items = []
# Note: Exclude payload from returned columns if not needed to optimize performance
select_columns = (
[PGVectorDataPoint]
if include_payload
else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector]
)
# Use async session to connect to the database
async with self.get_async_session() as session:
query = select(
PGVectorDataPoint,
*select_columns,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
).order_by("similarity")
@ -344,7 +351,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_list.append(
{
"id": parse_id(str(vector.id)),
"payload": vector.payload,
"payload": vector.payload if include_payload else None,
"_distance": vector.similarity,
}
)
@ -359,7 +366,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Create and return ScoredResult objects
return [
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
ScoredResult(
id=row.get("id"),
payload=row.get("payload") if include_payload else None,
score=row.get("score"),
)
for row in vector_list
]
@ -369,6 +380,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query_texts: List[str],
limit: int = None,
with_vectors: bool = False,
include_payload: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
@ -379,6 +391,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
include_payload=include_payload,
)
for query_vector in query_vectors
]

View file

@ -87,6 +87,7 @@ class VectorDBInterface(Protocol):
query_vector: Optional[List[float]],
limit: Optional[int],
with_vector: bool = False,
include_payload: bool = False,
):
"""
Perform a search in the specified collection using either a text query or a vector
@ -103,6 +104,9 @@ class VectorDBInterface(Protocol):
- limit (Optional[int]): The maximum number of results to return from the search.
- with_vector (bool): Whether to return the vector representations with search
results. (default False)
- include_payload (bool): Whether to include the payload data with search. Search is faster when set to False.
Payload contains metadata about the data point, useful for searches that are only based on embedding distances
like the RAG_COMPLETION search type, but not needed when search also contains graph data.
"""
raise NotImplementedError
@ -113,6 +117,7 @@ class VectorDBInterface(Protocol):
query_texts: List[str],
limit: Optional[int],
with_vectors: bool = False,
include_payload: bool = False,
):
"""
Perform a batch search using multiple text queries against a collection.
@ -125,6 +130,9 @@ class VectorDBInterface(Protocol):
- limit (Optional[int]): The maximum number of results to return for each query.
- with_vectors (bool): Whether to include vector representations with search
results. (default False)
- include_payload (bool): Whether to include the payload data with search. Search is faster when set to False.
Payload contains metadata about the data point, useful for searches that are only based on embedding distances
like the RAG_COMPLETION search type, but not needed when search also contains graph data.
"""
raise NotImplementedError

View file

@ -1,5 +1,6 @@
import time
from cognee.shared.logging_utils import get_logger
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from typing import List, Dict, Union, Optional, Type, Iterable, Tuple, Callable, Any
from cognee.modules.graph.exceptions import (
@ -44,6 +45,12 @@ class CogneeGraph(CogneeAbstractGraph):
def add_edge(self, edge: Edge) -> None:
self.edges.append(edge)
edge_text = edge.attributes.get("edge_text") or edge.attributes.get("relationship_type")
edge.attributes["edge_type_id"] = (
generate_edge_id(edge_id=edge_text) if edge_text else None
) # Update edge with generated edge_type_id
edge.node1.add_skeleton_edge(edge)
edge.node2.add_skeleton_edge(edge)
key = edge.get_distance_key()
@ -284,13 +291,7 @@ class CogneeGraph(CogneeAbstractGraph):
for query_index, scored_results in enumerate(per_query_scored_results):
for result in scored_results:
payload = getattr(result, "payload", None)
if not isinstance(payload, dict):
continue
text = payload.get("text")
if not text:
continue
matching_edges = self.edges_by_distance_key.get(str(text))
matching_edges = self.edges_by_distance_key.get(str(result.id))
if not matching_edges:
continue
for edge in matching_edges:

View file

@ -141,7 +141,7 @@ class Edge:
self.status = np.ones(dimension, dtype=int)
def get_distance_key(self) -> Optional[str]:
key = self.attributes.get("edge_text") or self.attributes.get("relationship_type")
key = self.attributes.get("edge_type_id")
if key is None:
return None
return str(key)

View file

@ -47,7 +47,9 @@ class ChunksRetriever(BaseRetriever):
vector_engine = get_vector_engine()
try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
found_chunks = await vector_engine.search(
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
)
logger.info(f"Found {len(found_chunks)} chunks from vector search")
await update_node_access_timestamps(found_chunks)

View file

@ -62,7 +62,9 @@ class CompletionRetriever(BaseRetriever):
vector_engine = get_vector_engine()
try:
found_chunks = await vector_engine.search("DocumentChunk_text", query, limit=self.top_k)
found_chunks = await vector_engine.search(
"DocumentChunk_text", query, limit=self.top_k, include_payload=True
)
if len(found_chunks) == 0:
return ""

View file

@ -52,7 +52,7 @@ class SummariesRetriever(BaseRetriever):
try:
summaries_results = await vector_engine.search(
"TextSummary_text", query, limit=self.top_k
"TextSummary_text", query, limit=self.top_k, include_payload=True
)
logger.info(f"Found {len(summaries_results)} summaries from vector search")

View file

@ -98,7 +98,7 @@ class TemporalRetriever(GraphCompletionRetriever):
async def filter_top_k_events(self, relevant_events, scored_results):
# Build a score lookup from vector search results
score_lookup = {res.payload["id"]: res.score for res in scored_results}
score_lookup = {res.id: res.score for res in scored_results}
events_with_scores = []
for event in relevant_events[0]["events"]:

View file

@ -67,7 +67,9 @@ class TripletRetriever(BaseRetriever):
"In order to use TRIPLET_COMPLETION first use the create_triplet_embeddings memify pipeline. "
)
found_triplets = await vector_engine.search("Triplet_text", query, limit=self.top_k)
found_triplets = await vector_engine.search(
"Triplet_text", query, limit=self.top_k, include_payload=True
)
if len(found_triplets) == 0:
return ""

View file

@ -4,6 +4,4 @@ Custom exceptions for the Cognee API.
This module defines a set of exceptions for handling various shared utility errors
"""
from .exceptions import (
IngestionError,
)
from .exceptions import IngestionError, UsageLoggerError

View file

@ -1,4 +1,4 @@
from cognee.exceptions import CogneeValidationError
from cognee.exceptions import CogneeConfigurationError, CogneeValidationError
from fastapi import status
@ -10,3 +10,13 @@ class IngestionError(CogneeValidationError):
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
):
super().__init__(message, name, status_code)
class UsageLoggerError(CogneeConfigurationError):
    """Raised when the usage-logging configuration is invalid.

    Maps to an HTTP 500 response by default, following the pattern of the
    other exception classes in this module.
    """

    def __init__(
        self,
        message: str = "Usage logging configuration is invalid.",
        name: str = "UsageLoggerError",
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    ):
        # Delegate message/name/status handling to the configuration-error base.
        super().__init__(message, name, status_code)

View file

@ -0,0 +1,332 @@
import asyncio
import inspect
import os
from datetime import datetime, timezone
from functools import singledispatch, wraps
from typing import Any, Callable, Optional
from uuid import UUID
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.get_cache_engine import get_cache_engine
from cognee.shared.exceptions import UsageLoggerError
from cognee.shared.logging_utils import get_logger
from cognee import __version__ as cognee_version
logger = get_logger("usage_logger")
@singledispatch
def _sanitize_value(value: Any) -> Any:
"""Default handler for JSON serialization - converts to string."""
try:
str_repr = str(value)
if str_repr.startswith("<") and str_repr.endswith(">"):
return f"<cannot be serialized: {type(value).__name__}>"
return str_repr
except Exception:
return f"<cannot be serialized: {type(value).__name__}>"
@_sanitize_value.register(type(None))
def _(value: None) -> None:
"""Handle None values - returns None as-is."""
return None
@_sanitize_value.register(str)
@_sanitize_value.register(int)
@_sanitize_value.register(float)
@_sanitize_value.register(bool)
def _(value: str | int | float | bool) -> str | int | float | bool:
"""Handle primitive types - returns value as-is since they're JSON-serializable."""
return value
@_sanitize_value.register(UUID)
def _(value: UUID) -> str:
"""Convert UUID to string representation."""
return str(value)
@_sanitize_value.register(datetime)
def _(value: datetime) -> str:
"""Convert datetime to ISO format string."""
return value.isoformat()
@_sanitize_value.register(list)
@_sanitize_value.register(tuple)
def _(value: list | tuple) -> list:
"""Recursively sanitize list or tuple elements."""
return [_sanitize_value(v) for v in value]
@_sanitize_value.register(dict)
def _(value: dict) -> dict:
"""Recursively sanitize dictionary keys and values."""
sanitized = {}
for k, v in value.items():
key_str = k if isinstance(k, str) else _sanitize_dict_key(k)
sanitized[key_str] = _sanitize_value(v)
return sanitized
def _sanitize_dict_key(key: Any) -> str:
"""Convert a non-string dict key to a string."""
sanitized_key = _sanitize_value(key)
if isinstance(sanitized_key, str):
if sanitized_key.startswith("<cannot be serialized"):
return f"<key:{type(key).__name__}>"
return sanitized_key
return str(sanitized_key)
def _get_param_names(func: Callable) -> list[str]:
"""Get parameter names from function signature."""
try:
return list(inspect.signature(func).parameters.keys())
except Exception:
return []
def _get_param_defaults(func: Callable) -> dict[str, Any]:
"""Get parameter defaults from function signature."""
try:
sig = inspect.signature(func)
defaults = {}
for param_name, param in sig.parameters.items():
if param.default != inspect.Parameter.empty:
defaults[param_name] = param.default
return defaults
except Exception:
return {}
def _extract_user_id(args: tuple, kwargs: dict, param_names: list[str]) -> Optional[str]:
"""Extract user_id from function arguments if available."""
try:
if "user" in kwargs and kwargs["user"] is not None:
user = kwargs["user"]
if hasattr(user, "id"):
return str(user.id)
for i, param_name in enumerate(param_names):
if i < len(args) and param_name == "user":
user = args[i]
if user is not None and hasattr(user, "id"):
return str(user.id)
return None
except Exception:
return None
def _extract_parameters(args: tuple, kwargs: dict, param_names: list[str], func: Callable) -> dict:
    """Collect the call's arguments as a JSON-safe dict, excluding `user`.

    Keyword arguments are captured first, then positional arguments mapped
    through *param_names*, then any declared defaults not otherwise supplied.
    When the signature is unknown, positionals fall back to arg_<i> keys.
    """
    collected = {
        name: _sanitize_value(value) for name, value in kwargs.items() if name != "user"
    }
    if not param_names:
        # Signature unknown: record positionals under synthetic names.
        collected.update(
            {f"arg_{index}": _sanitize_value(value) for index, value in enumerate(args)}
        )
        return collected
    for index, name in enumerate(param_names):
        if index < len(args) and name != "user" and name not in kwargs:
            collected[name] = _sanitize_value(args[index])
    defaults = _get_param_defaults(func)
    for name in param_names:
        if name != "user" and name not in collected and name in defaults:
            collected[name] = _sanitize_value(defaults[name])
    return collected
async def _log_usage_async(
    function_name: str,
    log_type: str,
    user_id: Optional[str],
    parameters: dict,
    result: Any,
    success: bool,
    error: Optional[str],
    duration_ms: float,
    start_time: datetime,
    end_time: datetime,
):
    """Persist one usage record for *function_name* to the cache backend.

    Builds a JSON-serializable entry describing a single decorated call —
    identity (function name, log type, user), sanitized inputs and output,
    outcome, and timing — and writes it via the cache engine using the TTL
    from the cache config. Best effort: every failure is logged and swallowed
    so usage logging can never break the wrapped call.

    Args:
        function_name: Name of the function being logged.
        log_type: Type of log entry (e.g., "api_endpoint", "mcp_tool", "function").
        user_id: User identifier, or None to record "unknown".
        parameters: Sanitized dictionary of function parameters.
        result: Function return value (sanitized before storage).
        success: Whether the function completed without raising.
        error: Error message if the function failed, None otherwise.
        duration_ms: Execution duration in milliseconds.
        start_time: Call start timestamp.
        end_time: Call end timestamp.
    """
    try:
        logger.debug(f"Starting to log usage for {function_name} at {start_time.isoformat()}")
        cache_config = get_cache_config()
        if not cache_config.usage_logging:
            logger.debug("Usage logging disabled, skipping log")
            return
        logger.debug(f"Getting cache engine for {function_name}")
        engine = get_cache_engine()
        if engine is None:
            logger.warning(
                f"Cache engine not available for usage logging (function: {function_name})"
            )
            return
        logger.debug(f"Cache engine obtained for {function_name}")
        if user_id is None:
            # Anonymous calls (e.g. MCP tools) are grouped under "unknown".
            logger.debug(f"No user_id provided, using 'unknown' for {function_name}")
            user_id = "unknown"
        entry = {
            "timestamp": start_time.isoformat(),
            "type": log_type,
            "function_name": function_name,
            "user_id": user_id,
            "parameters": parameters,
            "result": _sanitize_value(result),
            "success": success,
            "error": error,
            "duration_ms": round(duration_ms, 2),
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "metadata": {
                "cognee_version": cognee_version,
                "environment": os.getenv("ENV", "prod"),
            },
        }
        logger.debug(f"Calling log_usage for {function_name}, user_id={user_id}")
        await engine.log_usage(
            user_id=user_id,
            log_entry=entry,
            ttl=cache_config.usage_logging_ttl,
        )
        logger.info(f"Successfully logged usage for {function_name} (user_id={user_id})")
    except Exception as e:
        logger.error(f"Failed to log usage for {function_name}: {str(e)}", exc_info=True)
def log_usage(function_name: Optional[str] = None, log_type: str = "function"):
    """
    Decorator that records each call of an async function to the usage log.

    Fully transparent: arguments, return value, and raised exceptions pass
    through unchanged. For every call it records the function name, sanitized
    parameters, result, user (when a ``user`` argument is present), timing,
    and success/failure.

    Args:
        function_name: Display name for the log entry (defaults to func.__name__)
        log_type: Category of the entry (e.g., "api_endpoint", "mcp_tool")

    Usage:
        @log_usage(function_name="MCP my_mcp_tool", log_type="mcp_tool")
        async def my_mcp_tool(...):
            # mcp code

        @log_usage(function_name="POST API /v1/add", log_type="api_endpoint")
        async def add(...):
            # endpoint code
    """

    def decorator(func: Callable) -> Callable:
        """Wrap *func* with usage logging.

        Args:
            func: The async function to wrap.

        Returns:
            Callable: The wrapped function.

        Raises:
            UsageLoggerError: If *func* is not a coroutine function.
        """
        if not inspect.iscoroutinefunction(func):
            raise UsageLoggerError(
                f"@log_usage requires an async function. Got {func.__name__} which is not async."
            )

        @wraps(func)
        async def wrapped(*args, **kwargs):
            """Run *func* and record the call outcome without altering it.

            Captures the user id and parameters up front, awaits the original
            function, then in a ``finally`` block writes one usage record —
            success or failure — before any exception is re-raised.
            """
            # When logging is globally disabled, add no overhead beyond this check.
            if not get_cache_config().usage_logging:
                return await func(*args, **kwargs)
            started_at = datetime.now(timezone.utc)
            names = _get_param_names(func)
            caller_id = _extract_user_id(args, kwargs, names)
            captured = _extract_parameters(args, kwargs, names, func)
            outcome = None
            succeeded = True
            failure = None
            try:
                outcome = await func(*args, **kwargs)
                return outcome
            except Exception as e:
                succeeded = False
                failure = str(e)
                raise
            finally:
                finished_at = datetime.now(timezone.utc)
                elapsed_ms = (finished_at - started_at).total_seconds() * 1000
                try:
                    await _log_usage_async(
                        function_name=function_name or func.__name__,
                        log_type=log_type,
                        user_id=caller_id,
                        parameters=captured,
                        result=outcome,
                        success=succeeded,
                        error=failure,
                        duration_ms=elapsed_ms,
                        start_time=started_at,
                        end_time=finished_at,
                    )
                except Exception as e:
                    logger.error(
                        f"Failed to log usage for {function_name or func.__name__}: {str(e)}",
                        exc_info=True,
                    )

        return wrapped

    return decorator

View file

@ -0,0 +1,255 @@
"""Integration tests for usage logger with real Redis components."""
import os
import pytest
import asyncio
from datetime import datetime, timezone
from types import SimpleNamespace
from uuid import UUID
from unittest.mock import patch
from cognee.shared.usage_logger import log_usage
from cognee.infrastructure.databases.cache.config import get_cache_config
from cognee.infrastructure.databases.cache.get_cache_engine import (
get_cache_engine,
create_cache_engine,
)
@pytest.fixture
def usage_logging_config():
    """Point the cache layer at local Redis with usage logging switched on.

    Snapshots os.environ, applies the test settings, and clears the cached
    config/engine factories so the new values are actually picked up; the
    snapshot and caches are restored on teardown.
    """
    saved_environ = dict(os.environ)
    os.environ.update(
        {
            "USAGE_LOGGING": "true",
            "CACHE_BACKEND": "redis",
            "CACHE_HOST": "localhost",
            "CACHE_PORT": "6379",
        }
    )
    get_cache_config.cache_clear()
    create_cache_engine.cache_clear()
    yield
    os.environ.clear()
    os.environ.update(saved_environ)
    get_cache_config.cache_clear()
    create_cache_engine.cache_clear()
@pytest.fixture
def usage_logging_disabled():
    """Configure a redis cache backend but leave usage logging switched off.

    Mirrors usage_logging_config's snapshot/restore and cache-clearing
    behavior, with USAGE_LOGGING explicitly set to "false".
    """
    saved_environ = dict(os.environ)
    os.environ.update({"USAGE_LOGGING": "false", "CACHE_BACKEND": "redis"})
    get_cache_config.cache_clear()
    create_cache_engine.cache_clear()
    yield
    os.environ.clear()
    os.environ.update(saved_environ)
    get_cache_config.cache_clear()
    create_cache_engine.cache_clear()
@pytest.fixture
def redis_adapter():
    """Yield a real RedisAdapter against localhost, skipping if construction fails."""
    from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter

    try:
        adapter = RedisAdapter(host="localhost", port=6379, log_key="test_usage_logs")
    except Exception as e:
        pytest.skip(f"Redis not available: {e}")
    yield adapter
@pytest.fixture
def test_user():
    """Minimal stand-in user exposing only the ``id`` attribute the logger reads."""
    stub = SimpleNamespace(id="test-user-123")
    return stub
class TestDecoratorBehavior:
    """Exercise the @log_usage decorator against a real Redis-backed setup.

    Covers the disabled path, the missing-cache-engine path, the full shape
    of a logged entry, and repeated consecutive calls.
    """
    @pytest.mark.asyncio
    async def test_decorator_configuration(
        self, usage_logging_disabled, usage_logging_config, redis_adapter
    ):
        """Test decorator skips when disabled and logs when enabled."""
        # NOTE(review): this test requests BOTH usage_logging_disabled and
        # usage_logging_config; both fixtures rewrite USAGE_LOGGING in
        # os.environ, so whichever fixture is set up last wins. The "Test
        # disabled" section below may therefore run with logging enabled —
        # confirm the intended fixture ordering or split into two tests.
        # Test disabled
        call_count = 0
        @log_usage(function_name="test_func", log_type="test")
        async def test_func():
            nonlocal call_count
            call_count += 1
            return "result"
        # The decorator must be transparent: same result, body runs exactly once.
        assert await test_func() == "result"
        assert call_count == 1
        # Test enabled with cache engine None
        with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
            mock_get.return_value = None
            # A missing cache engine must never break the wrapped call.
            assert await test_func() == "result"
    @pytest.mark.asyncio
    async def test_decorator_logging(self, usage_logging_config, redis_adapter, test_user):
        """Test decorator logs to Redis with correct structure."""
        @log_usage(function_name="test_func", log_type="test")
        async def test_func(param1: str, param2: int = 42, user=None):
            await asyncio.sleep(0.01)
            return {"result": f"{param1}_{param2}"}
        with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
            mock_get.return_value = redis_adapter
            result = await test_func("value1", user=test_user)
            assert result == {"result": "value1_42"}
            # Logs come back newest-first, so index 0 is the call above.
            logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
            log = logs[0]
            assert log["function_name"] == "test_func"
            assert log["type"] == "test"
            assert log["user_id"] == "test-user-123"
            assert log["parameters"]["param1"] == "value1"
            # param2 was not passed explicitly: its declared default must be logged.
            assert log["parameters"]["param2"] == 42
            assert log["success"] is True
            # Every entry must carry the full envelope of bookkeeping fields.
            assert all(
                field in log
                for field in [
                    "timestamp",
                    "result",
                    "error",
                    "duration_ms",
                    "start_time",
                    "end_time",
                    "metadata",
                ]
            )
            assert "cognee_version" in log["metadata"]
    @pytest.mark.asyncio
    async def test_multiple_calls(self, usage_logging_config, redis_adapter, test_user):
        """Test multiple consecutive calls are all logged."""
        @log_usage(function_name="multi_test", log_type="test")
        async def multi_func(call_num: int, user=None):
            return {"call": call_num}
        with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
            mock_get.return_value = redis_adapter
            for i in range(3):
                await multi_func(i, user=test_user)
            logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
            # >= 3 because earlier tests may have logged under the same user key.
            assert len(logs) >= 3
            # Compare as a set: only the presence of all three calls matters here.
            call_nums = {log["parameters"]["call_num"] for log in logs[:3]}
            assert call_nums == {0, 1, 2}
class TestRealRedisIntegration:
    """Verify storage, retrieval order, limit handling, and TTL against real Redis."""
    @pytest.mark.asyncio
    async def test_redis_storage_retrieval_and_ttl(
        self, usage_logging_config, redis_adapter, test_user
    ):
        """Test logs are stored, retrieved with correct order/limits, and TTL is set."""
        @log_usage(function_name="redis_test", log_type="test")
        async def redis_func(data: str, user=None):
            return {"processed": data}
        @log_usage(function_name="order_test", log_type="test")
        async def order_func(num: int, user=None):
            return {"num": num}
        with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
            mock_get.return_value = redis_adapter
            # Storage
            await redis_func("test_data", user=test_user)
            logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
            assert logs[0]["function_name"] == "redis_test"
            assert logs[0]["parameters"]["data"] == "test_data"
            # Order (most recent first)
            for i in range(3):
                await order_func(i, user=test_user)
                # Small pause so each entry gets a distinct timestamp.
                await asyncio.sleep(0.01)
            logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
            assert [log["parameters"]["num"] for log in logs[:3]] == [2, 1, 0]
            # Limit
            assert len(await redis_adapter.get_usage_logs("test-user-123", limit=2)) == 2
            # TTL
            # Key is "<log_key>:<user_id>" with log_key="test_usage_logs" from the fixture.
            ttl = await redis_adapter.async_redis.ttl("test_usage_logs:test-user-123")
            # 604800 s = 7 days, the default usage_logging_ttl in the cache config.
            assert 0 < ttl <= 604800
class TestEdgeCases:
    """Edge cases: empty params, defaults, nested non-JSON types, exceptions, None."""
    @pytest.mark.asyncio
    async def test_edge_cases(self, usage_logging_config, redis_adapter, test_user):
        """Test no params, defaults, complex structures, exceptions, None, circular refs."""
        @log_usage(function_name="no_params", log_type="test")
        async def no_params_func(user=None):
            return "result"
        @log_usage(function_name="defaults_only", log_type="test")
        async def defaults_only_func(param1: str = "default1", param2: int = 42, user=None):
            return {"param1": param1, "param2": param2}
        @log_usage(function_name="complex_test", log_type="test")
        async def complex_func(user=None):
            return {
                "nested": {
                    "list": [1, 2, 3],
                    "uuid": UUID("123e4567-e89b-12d3-a456-426614174000"),
                    "datetime": datetime(2024, 1, 15, tzinfo=timezone.utc),
                }
            }
        @log_usage(function_name="exception_test", log_type="test")
        async def exception_func(user=None):
            raise RuntimeError("Test exception")
        @log_usage(function_name="none_test", log_type="test")
        async def none_func(user=None):
            return None
        with patch("cognee.shared.usage_logger.get_cache_engine") as mock_get:
            mock_get.return_value = redis_adapter
            # No parameters
            # The `user` kwarg is deliberately excluded from logged parameters.
            await no_params_func(user=test_user)
            logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
            assert logs[0]["parameters"] == {}
            # Default parameters
            # Declared defaults must be captured even when the caller passes nothing.
            await defaults_only_func(user=test_user)
            logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
            assert logs[0]["parameters"]["param1"] == "default1"
            assert logs[0]["parameters"]["param2"] == 42
            # Complex nested structures
            # Non-JSON types (UUID, datetime) must be sanitized to strings for storage.
            await complex_func(user=test_user)
            logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
            assert isinstance(logs[0]["result"]["nested"]["uuid"], str)
            assert isinstance(logs[0]["result"]["nested"]["datetime"], str)
            # Exception handling
            # The exception must propagate AND be recorded as a failed entry.
            with pytest.raises(RuntimeError):
                await exception_func(user=test_user)
            logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
            assert logs[0]["success"] is False
            assert "Test exception" in logs[0]["error"]
            # None return value
            assert await none_func(user=test_user) is None
            logs = await redis_adapter.get_usage_logs("test-user-123", limit=10)
            assert logs[0]["result"] is None

View file

@ -97,7 +97,7 @@ async def test_vector_engine_search_none_limit():
query_vector = (await vector_engine.embedding_engine.embed_text([query_text]))[0]
result = await vector_engine.search(
collection_name=collection_name, query_vector=query_vector, limit=None
collection_name=collection_name, query_vector=query_vector, limit=None, include_payload=True
)
# Check that we did not accidentally use any default value for limit

View file

@ -70,7 +70,9 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -149,7 +149,9 @@ async def main():
await test_getting_of_documents(dataset_name_1)
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -48,7 +48,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -63,7 +63,9 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -52,7 +52,9 @@ async def main():
await cognee.cognify([dataset_name])
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -41,14 +41,14 @@ async def _reset_engines_and_prune() -> None:
except Exception:
pass
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.infrastructure.databases.relational.create_relational_engine import (
create_relational_engine,
)
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine
create_graph_engine.cache_clear()
create_vector_engine.cache_clear()
_create_graph_engine.cache_clear()
_create_vector_engine.cache_clear()
create_relational_engine.cache_clear()
await cognee.prune.prune_data()

View file

@ -163,7 +163,9 @@ async def main():
await test_getting_of_documents(dataset_name_1)
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -58,7 +58,9 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "Quantum computer"))[0]
random_node = (
await vector_engine.search("Entity_name", "Quantum computer", include_payload=True)
)[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -43,7 +43,7 @@ async def main():
from cognee.infrastructure.databases.vector import get_vector_engine
vector_engine = get_vector_engine()
random_node = (await vector_engine.search("Entity_name", "AI"))[0]
random_node = (await vector_engine.search("Entity_name", "AI", include_payload=True))[0]
random_node_name = random_node.payload["text"]
search_results = await cognee.search(

View file

@ -48,14 +48,14 @@ async def _reset_engines_and_prune() -> None:
# Engine might not exist yet
pass
from cognee.infrastructure.databases.graph.get_graph_engine import create_graph_engine
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.infrastructure.databases.graph.get_graph_engine import _create_graph_engine
from cognee.infrastructure.databases.vector.create_vector_engine import _create_vector_engine
from cognee.infrastructure.databases.relational.create_relational_engine import (
create_relational_engine,
)
create_graph_engine.cache_clear()
create_vector_engine.cache_clear()
_create_graph_engine.cache_clear()
_create_vector_engine.cache_clear()
create_relational_engine.cache_clear()
await cognee.prune.prune_data()

View file

@ -0,0 +1,268 @@
import os
import pytest
import pytest_asyncio
import asyncio
from fastapi.testclient import TestClient
import cognee
from cognee.api.client import app
from cognee.modules.users.methods import get_default_user, get_authenticated_user
async def _reset_engines_and_prune():
    """Dispose the vector engine connection (if any) and wipe cognee data + system state."""
    try:
        from cognee.infrastructure.databases.vector import get_vector_engine

        engine = get_vector_engine()
        inner = getattr(engine, "engine", None)
        if inner is not None and hasattr(inner, "dispose"):
            await inner.dispose(close=True)
    except Exception:
        # Best effort: the engine may not exist yet or may not be disposable.
        pass
    await cognee.prune.prune_data()
    await cognee.prune.prune_system(metadata=True)
@pytest.fixture(scope="session")
def event_loop():
    """Provide one shared asyncio event loop for every test in this module."""
    session_loop = asyncio.new_event_loop()
    try:
        yield session_loop
    finally:
        session_loop.close()
@pytest.fixture(scope="session")
def e2e_config():
    """Enable Redis-backed usage logging for the whole test session.

    Snapshots os.environ, applies the test settings, and restores the
    snapshot at session teardown.
    """
    saved_environ = dict(os.environ)
    os.environ.update(
        {
            "USAGE_LOGGING": "true",
            "CACHE_BACKEND": "redis",
            "CACHE_HOST": "localhost",
            "CACHE_PORT": "6379",
        }
    )
    yield
    os.environ.clear()
    os.environ.update(saved_environ)
@pytest.fixture(scope="session")
def authenticated_client(test_client):
    """Yield the API client with authentication overridden to the default user."""

    async def _default_user_override():
        return await get_default_user()

    app.dependency_overrides[get_authenticated_user] = _default_user_override
    yield test_client
    # Remove only our override; other overrides (if any) stay untouched.
    app.dependency_overrides.pop(get_authenticated_user, None)
@pytest_asyncio.fixture(scope="session")
async def test_data_setup():
    """Prune, ingest one sample sentence into a dataset, cognify it, and clean up after."""
    await _reset_engines_and_prune()
    dataset_name = "test_e2e_dataset"
    await cognee.add("Germany is located in Europe right next to the Netherlands.", dataset_name)
    await cognee.cognify([dataset_name])
    yield dataset_name
    await _reset_engines_and_prune()
@pytest_asyncio.fixture
async def mcp_data_setup():
    """Per-test variant of the data setup used by the MCP tests (separate dataset)."""
    await _reset_engines_and_prune()
    dataset_name = "test_mcp_dataset"
    await cognee.add("Germany is located in Europe right next to the Netherlands.", dataset_name)
    await cognee.cognify([dataset_name])
    yield dataset_name
    await _reset_engines_and_prune()
@pytest.fixture(scope="session")
def test_client():
    """Session-wide FastAPI TestClient bound to the cognee app."""
    with TestClient(app) as api_client:
        yield api_client
@pytest_asyncio.fixture
async def cache_engine(e2e_config):
    """Build a RedisAdapter in the test's own event loop for reading usage logs."""
    from cognee.infrastructure.databases.cache.redis.RedisAdapter import RedisAdapter
    from cognee.infrastructure.databases.cache.config import get_cache_config

    config = get_cache_config()
    if not config.usage_logging or config.cache_backend != "redis":
        pytest.skip("Redis usage logging not configured")
    return RedisAdapter(
        host=config.cache_host,
        port=config.cache_port,
        username=config.cache_username,
        password=config.cache_password,
        log_key="usage_logs",
    )
@pytest.mark.asyncio
async def test_api_endpoint_logging(e2e_config, authenticated_client, cache_engine):
    """Test that API endpoints succeed and log to Redis.

    Drives the add -> cognify -> search flow over real HTTP and then asserts
    that each endpoint produced an "api_endpoint" usage log entry for the
    default user.
    """
    user = await get_default_user()
    dataset_name = "test_e2e_api_dataset"
    # Upload a small text file through the multipart /add endpoint.
    add_response = authenticated_client.post(
        "/api/v1/add",
        data={"datasetName": dataset_name},
        files=[
            (
                "data",
                (
                    "test.txt",
                    b"Germany is located in Europe right next to the Netherlands.",
                    "text/plain",
                ),
            )
        ],
    )
    assert add_response.status_code in [200, 201], f"Add endpoint failed: {add_response.text}"
    # run_in_background=False so cognify completes before we query the logs.
    cognify_response = authenticated_client.post(
        "/api/v1/cognify",
        json={"datasets": [dataset_name], "run_in_background": False},
    )
    assert cognify_response.status_code in [200, 201], (
        f"Cognify endpoint failed: {cognify_response.text}"
    )
    search_response = authenticated_client.post(
        "/api/v1/search",
        json={"query": "Germany", "search_type": "GRAPH_COMPLETION", "datasets": [dataset_name]},
    )
    assert search_response.status_code == 200, f"Search endpoint failed: {search_response.text}"
    # Each endpoint must have produced a usage log entry under the real user id.
    logs = await cache_engine.get_usage_logs(str(user.id), limit=20)
    add_logs = [log for log in logs if log.get("function_name") == "POST /v1/add"]
    assert len(add_logs) > 0
    assert add_logs[0]["type"] == "api_endpoint"
    assert add_logs[0]["user_id"] == str(user.id)
    assert add_logs[0]["success"] is True
    cognify_logs = [log for log in logs if log.get("function_name") == "POST /v1/cognify"]
    assert len(cognify_logs) > 0
    assert cognify_logs[0]["type"] == "api_endpoint"
    assert cognify_logs[0]["user_id"] == str(user.id)
    assert cognify_logs[0]["success"] is True
    search_logs = [log for log in logs if log.get("function_name") == "POST /v1/search"]
    assert len(search_logs) > 0
    assert search_logs[0]["type"] == "api_endpoint"
    assert search_logs[0]["user_id"] == str(user.id)
    assert search_logs[0]["success"] is True
@pytest.mark.asyncio
async def test_mcp_tool_logging(e2e_config, cache_engine):
    """Test that MCP tools succeed and log to Redis.

    Loads the MCP server module straight from the cognee-mcp source tree,
    exercises its tools (cognify, list_data, search, save_interaction,
    cognify_status, prune), then asserts that exactly those six tool calls
    were logged under user_id "unknown" (MCP calls carry no user object).
    """
    import sys
    import importlib.util
    from pathlib import Path
    await _reset_engines_and_prune()
    # Resolve cognee-mcp/src relative to this test file's location in the repo.
    repo_root = Path(__file__).parent.parent.parent
    mcp_src_path = repo_root / "cognee-mcp" / "src"
    mcp_server_path = mcp_src_path / "server.py"
    if not mcp_server_path.exists():
        pytest.skip(f"MCP server not found at {mcp_server_path}")
    if str(mcp_src_path) not in sys.path:
        sys.path.insert(0, str(mcp_src_path))
    spec = importlib.util.spec_from_file_location("mcp_server_module", mcp_server_path)
    mcp_server_module = importlib.util.module_from_spec(spec)
    import os
    original_cwd = os.getcwd()
    # NOTE(review): server.py apparently resolves paths relative to the CWD at
    # import time, hence this chdir-around-exec workaround — confirm.
    try:
        os.chdir(str(mcp_src_path))
        spec.loader.exec_module(mcp_server_module)
    finally:
        os.chdir(original_cwd)
    # If the server module did not build its client at import, construct one
    # directly from cognee_client.py so the tools have a backend to talk to.
    if mcp_server_module.cognee_client is None:
        cognee_client_path = mcp_src_path / "cognee_client.py"
        if cognee_client_path.exists():
            spec_client = importlib.util.spec_from_file_location(
                "cognee_client", cognee_client_path
            )
            cognee_client_module = importlib.util.module_from_spec(spec_client)
            spec_client.loader.exec_module(cognee_client_module)
            CogneeClient = cognee_client_module.CogneeClient
            mcp_server_module.cognee_client = CogneeClient()
        else:
            pytest.skip(f"CogneeClient not found at {cognee_client_path}")
    test_text = "Germany is located in Europe right next to the Netherlands."
    await mcp_server_module.cognify(data=test_text)
    # NOTE(review): fixed 30 s sleeps assume background cognify finishes in
    # time — brittle; polling cognify_status would be more robust.
    await asyncio.sleep(30.0)
    list_result = await mcp_server_module.list_data()
    assert list_result is not None, "List data should return results"
    search_result = await mcp_server_module.search(
        search_query="Germany", search_type="GRAPH_COMPLETION", top_k=5
    )
    assert search_result is not None, "Search should return results"
    interaction_data = "User: What is Germany?\nAgent: Germany is a country in Europe."
    await mcp_server_module.save_interaction(data=interaction_data)
    await asyncio.sleep(30.0)
    status_result = await mcp_server_module.cognify_status()
    assert status_result is not None, "Cognify status should return results"
    await mcp_server_module.prune()
    # Short pause so the final (fire-and-forget) log write lands in Redis.
    await asyncio.sleep(0.5)
    # MCP tools have no user argument, so all entries land under "unknown".
    logs = await cache_engine.get_usage_logs("unknown", limit=50)
    mcp_logs = [log for log in logs if log.get("type") == "mcp_tool"]
    assert len(mcp_logs) > 0, (
        f"Should have MCP tool logs with user_id='unknown'. Found logs: {[log.get('function_name') for log in logs[:5]]}"
    )
    # Exactly one entry per tool invoked above.
    assert len(mcp_logs) == 6
    function_names = [log.get("function_name") for log in mcp_logs]
    expected_tools = [
        "MCP cognify",
        "MCP list_data",
        "MCP search",
        "MCP save_interaction",
        "MCP cognify_status",
        "MCP prune",
    ]
    for expected_tool in expected_tools:
        assert expected_tool in function_names, (
            f"Should have {expected_tool} log. Found: {function_names}"
        )
    for log in mcp_logs:
        assert log["type"] == "mcp_tool"
        assert log["user_id"] == "unknown"
        assert log["success"] is True

View file

@ -62,6 +62,8 @@ def test_cache_config_to_dict():
"cache_password": None,
"agentic_lock_expire": 100,
"agentic_lock_timeout": 200,
"usage_logging": False,
"usage_logging_ttl": 604800,
}

View file

@ -1,6 +1,7 @@
import pytest
from unittest.mock import AsyncMock
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.modules.graph.exceptions import EntityNotFoundError, EntityAlreadyExistsError
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.cognee_graph.CogneeGraphElements import Edge, Node
@ -379,7 +380,7 @@ async def test_map_vector_distances_to_graph_edges_with_payload(setup_graph):
graph.add_edge(edge)
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
MockScoredResult(generate_edge_id("CONNECTS_TO"), 0.92, payload={"text": "CONNECTS_TO"}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@ -404,8 +405,9 @@ async def test_map_vector_distances_partial_edge_coverage(setup_graph):
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_1_text = "CONNECTS_TO"
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "CONNECTS_TO"}),
MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@ -431,8 +433,9 @@ async def test_map_vector_distances_edges_fallback_to_relationship_type(setup_gr
)
graph.add_edge(edge)
edge_text = "KNOWS"
edge_distances = [
MockScoredResult("e1", 0.85, payload={"text": "KNOWS"}),
MockScoredResult(generate_edge_id(edge_text), 0.85, payload={"text": edge_text}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@ -457,8 +460,9 @@ async def test_map_vector_distances_no_edge_matches(setup_graph):
)
graph.add_edge(edge)
edge_text = "SOME_OTHER_EDGE"
edge_distances = [
MockScoredResult("e1", 0.92, payload={"text": "SOME_OTHER_EDGE"}),
MockScoredResult(generate_edge_id(edge_text), 0.92, payload={"text": edge_text}),
]
await graph.map_vector_distances_to_graph_edges(edge_distances=edge_distances)
@ -511,9 +515,15 @@ async def test_map_vector_distances_to_graph_edges_multi_query(setup_graph):
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_1_text = "A"
edge_2_text = "B"
edge_distances = [
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0
[MockScoredResult("e2", 0.2, payload={"text": "B"})], # query 1
[
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
], # query 0
[
MockScoredResult(generate_edge_id(edge_2_text), 0.2, payload={"text": edge_2_text})
], # query 1
]
await graph.map_vector_distances_to_graph_edges(
@ -541,8 +551,11 @@ async def test_map_vector_distances_to_graph_edges_preserves_unmapped_indices(se
graph.add_edge(edge1)
graph.add_edge(edge2)
edge_1_text = "A"
edge_distances = [
[MockScoredResult("e1", 0.1, payload={"text": "A"})], # query 0: only edge1 mapped
[
MockScoredResult(generate_edge_id(edge_1_text), 0.1, payload={"text": edge_1_text})
], # query 0: only edge1 mapped
[], # query 1: no edges mapped
]

View file

@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine):
assert len(context) == 2
assert context[0]["text"] == "Steve Rodger"
assert context[1]["text"] == "Mike Broski"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=5)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=5, include_payload=True
)
@pytest.mark.asyncio
@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
context = await retriever.get_context("test query")
assert len(context) == 3
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=3)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=3, include_payload=True
)
@pytest.mark.asyncio

View file

@ -33,7 +33,9 @@ async def test_get_context_success(mock_vector_engine):
context = await retriever.get_context("test query")
assert context == "Steve Rodger\nMike Broski"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=2, include_payload=True
)
@pytest.mark.asyncio
@ -85,7 +87,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
context = await retriever.get_context("test query")
assert context == "Chunk 0\nChunk 1"
mock_vector_engine.search.assert_awaited_once_with("DocumentChunk_text", "test query", limit=2)
mock_vector_engine.search.assert_awaited_once_with(
"DocumentChunk_text", "test query", limit=2, include_payload=True
)
@pytest.mark.asyncio

View file

@ -35,7 +35,9 @@ async def test_get_context_success(mock_vector_engine):
assert len(context) == 2
assert context[0]["text"] == "S.R."
assert context[1]["text"] == "M.B."
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=5)
mock_vector_engine.search.assert_awaited_once_with(
"TextSummary_text", "test query", limit=5, include_payload=True
)
@pytest.mark.asyncio
@ -87,7 +89,9 @@ async def test_get_context_top_k_limit(mock_vector_engine):
context = await retriever.get_context("test query")
assert len(context) == 3
mock_vector_engine.search.assert_awaited_once_with("TextSummary_text", "test query", limit=3)
mock_vector_engine.search.assert_awaited_once_with(
"TextSummary_text", "test query", limit=3, include_payload=True
)
@pytest.mark.asyncio

View file

@ -63,8 +63,8 @@ async def test_filter_top_k_events_sorts_and_limits():
]
scored_results = [
SimpleNamespace(payload={"id": "e2"}, score=0.10),
SimpleNamespace(payload={"id": "e1"}, score=0.20),
SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.10),
SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.20),
]
top = await tr.filter_top_k_events(relevant_events, scored_results)
@ -91,8 +91,8 @@ async def test_filter_top_k_events_includes_unknown_as_infinite_but_not_in_top_k
]
scored_results = [
SimpleNamespace(payload={"id": "known2"}, score=0.05),
SimpleNamespace(payload={"id": "known1"}, score=0.50),
SimpleNamespace(id="known2", payload={"id": "known2"}, score=0.05),
SimpleNamespace(id="known1", payload={"id": "known1"}, score=0.50),
]
top = await tr.filter_top_k_events(relevant_events, scored_results)
@ -119,8 +119,8 @@ async def test_filter_top_k_events_limits_when_top_k_exceeds_events():
tr = TemporalRetriever(top_k=10)
relevant_events = [{"events": [{"id": "a"}, {"id": "b"}]}]
scored_results = [
SimpleNamespace(payload={"id": "a"}, score=0.1),
SimpleNamespace(payload={"id": "b"}, score=0.2),
SimpleNamespace(id="a", payload={"id": "a"}, score=0.1),
SimpleNamespace(id="b", payload={"id": "b"}, score=0.2),
]
out = await tr.filter_top_k_events(relevant_events, scored_results)
assert [e["id"] for e in out] == ["a", "b"]
@ -179,8 +179,8 @@ async def test_get_context_with_time_range(mock_graph_engine, mock_vector_engine
}
]
mock_result1 = SimpleNamespace(payload={"id": "e2"}, score=0.05)
mock_result2 = SimpleNamespace(payload={"id": "e1"}, score=0.10)
mock_result1 = SimpleNamespace(id="e2", payload={"id": "e2"}, score=0.05)
mock_result2 = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.10)
mock_vector_engine.search.return_value = [mock_result1, mock_result2]
with (
@ -279,7 +279,7 @@ async def test_get_context_time_from_only(mock_graph_engine, mock_vector_engine)
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@ -313,7 +313,7 @@ async def test_get_context_time_to_only(mock_graph_engine, mock_vector_engine):
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@ -347,7 +347,7 @@ async def test_get_completion_without_context(mock_graph_engine, mock_vector_eng
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@ -416,7 +416,7 @@ async def test_get_completion_with_session(mock_graph_engine, mock_vector_engine
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
mock_user = MagicMock()
@ -481,7 +481,7 @@ async def test_get_completion_with_session_no_user_id(mock_graph_engine, mock_ve
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (
@ -570,7 +570,7 @@ async def test_get_completion_with_response_model(mock_graph_engine, mock_vector
}
]
mock_result = SimpleNamespace(payload={"id": "e1"}, score=0.05)
mock_result = SimpleNamespace(id="e1", payload={"id": "e1"}, score=0.05)
mock_vector_engine.search.return_value = [mock_result]
with (

View file

@ -6,6 +6,7 @@ from cognee.modules.retrieval.utils.brute_force_triplet_search import (
get_memory_fragment,
format_triplets,
)
from cognee.modules.engine.utils.generate_edge_id import generate_edge_id
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.modules.graph.exceptions.exceptions import EntityNotFoundError
from cognee.infrastructure.databases.vector.exceptions.exceptions import CollectionNotFoundError
@ -1036,9 +1037,11 @@ async def test_cognee_graph_mapping_batch_shapes():
]
}
edge_1_text = "relates_to"
edge_2_text = "relates_to"
edge_distances_batch = [
[MockScoredResult("edge1", 0.92, payload={"text": "relates_to"})],
[MockScoredResult("edge2", 0.88, payload={"text": "relates_to"})],
[MockScoredResult(generate_edge_id(edge_1_text), 0.92, payload={"text": edge_1_text})],
[MockScoredResult(generate_edge_id(edge_2_text), 0.88, payload={"text": edge_2_text})],
]
await graph.map_vector_distances_to_graph_nodes(

View file

@ -34,7 +34,9 @@ async def test_get_context_success(mock_vector_engine):
context = await retriever.get_context("test query")
assert context == "Alice knows Bob\nBob works at Tech Corp"
mock_vector_engine.search.assert_awaited_once_with("Triplet_text", "test query", limit=5)
mock_vector_engine.search.assert_awaited_once_with(
"Triplet_text", "test query", limit=5, include_payload=True
)
@pytest.mark.asyncio

View file

@ -0,0 +1,241 @@
"""Unit tests for usage logger core functions."""
import pytest
from datetime import datetime, timezone
from uuid import UUID
from types import SimpleNamespace
from cognee.shared.usage_logger import (
_sanitize_value,
_sanitize_dict_key,
_get_param_names,
_get_param_defaults,
_extract_user_id,
_extract_parameters,
log_usage,
)
from cognee.shared.exceptions import UsageLoggerError
class TestSanitizeValue:
    """Test _sanitize_value function."""

    @pytest.mark.parametrize(
        "value,expected",
        # JSON-safe primitives pass through unchanged, so each case is (v, v).
        [(v, v) for v in (None, "string", 42, 3.14, True, False)],
    )
    def test_basic_types(self, value, expected):
        assert _sanitize_value(value) == expected

    def test_uuid_and_datetime(self):
        """UUID and datetime values are serialized to their string forms."""
        some_uuid = UUID("123e4567-e89b-12d3-a456-426614174000")
        aware_dt = datetime(2024, 1, 15, 12, 30, 45, tzinfo=timezone.utc)

        assert _sanitize_value(some_uuid) == "123e4567-e89b-12d3-a456-426614174000"
        assert _sanitize_value(aware_dt) == "2024-01-15T12:30:45+00:00"

    def test_collections(self):
        """Lists, tuples, and dicts are sanitized element by element."""
        mixed = [1, "string", UUID("123e4567-e89b-12d3-a456-426614174000"), None]
        assert _sanitize_value(mixed) == [
            1,
            "string",
            "123e4567-e89b-12d3-a456-426614174000",
            None,
        ]

        # Tuples come back as lists.
        assert _sanitize_value((1, "string", True)) == [1, "string", True]

        keyed = {"key": UUID("123e4567-e89b-12d3-a456-426614174000")}
        assert _sanitize_value(keyed) == {"key": "123e4567-e89b-12d3-a456-426614174000"}

        # Empty containers stay empty.
        assert _sanitize_value([]) == []
        assert _sanitize_value({}) == {}

    def test_nested_and_complex(self):
        """Nested structures survive intact; unknown objects become strings."""
        deep = {"level1": {"level2": {"level3": [1, 2, {"nested": "value"}]}}}
        sanitized = _sanitize_value(deep)
        assert sanitized["level1"]["level2"]["level3"][2]["nested"] == "value"

        # An arbitrary object must degrade to some string representation.
        class CustomObject:
            def __str__(self):
                return "<CustomObject instance>"

        fallback = _sanitize_value(CustomObject())
        assert isinstance(fallback, str)
        assert "<cannot be serialized" in fallback or "<CustomObject" in fallback
class TestSanitizeDictKey:
    """Test _sanitize_dict_key function."""

    @pytest.mark.parametrize(
        "key,expected_contains",
        [
            ("simple_key", "simple_key"),
            (UUID("123e4567-e89b-12d3-a456-426614174000"), "123e4567-e89b-12d3-a456-426614174000"),
            ((1, 2, 3), ["1", "2"]),
        ],
    )
    def test_key_types(self, key, expected_contains):
        # Keys are always coerced to strings; the original content must survive.
        sanitized = _sanitize_dict_key(key)
        assert isinstance(sanitized, str)

        # Normalize the expectation to a list so one check covers both shapes.
        expected_fragments = (
            expected_contains if isinstance(expected_contains, list) else [expected_contains]
        )
        assert all(fragment in sanitized for fragment in expected_fragments)

    def test_non_serializable_key(self):
        # A key with no sensible serialization still yields some string marker.
        class BadKey:
            def __str__(self):
                return "<BadKey instance>"

        sanitized = _sanitize_dict_key(BadKey())
        assert isinstance(sanitized, str)
        assert "<key:" in sanitized or "<BadKey" in sanitized
class TestGetParamNames:
    """Test _get_param_names function."""

    @pytest.mark.parametrize(
        "func_def,expected",
        [
            (lambda a, b, c: None, ["a", "b", "c"]),
            (lambda a, b=42, c="default": None, ["a", "b", "c"]),
            (lambda a, **kwargs: None, ["a", "kwargs"]),
            (lambda *args: None, ["args"]),
        ],
    )
    def test_param_extraction(self, func_def, expected):
        # Positional, defaulted, *args and **kwargs names are all reported.
        assert _get_param_names(func_def) == expected

    def test_async_function(self):
        # Coroutine functions are inspected the same way as plain ones.
        async def sample(a, b):
            pass

        assert _get_param_names(sample) == ["a", "b"]
class TestGetParamDefaults:
    """Test _get_param_defaults function."""

    @pytest.mark.parametrize(
        "func_def,expected",
        [
            # Mixed required/defaulted parameters: only the defaulted ones appear.
            (lambda a, b=42, c="default", d=None: None, {"b": 42, "c": "default", "d": None}),
            # No defaults at all -> empty mapping.
            (lambda a, b, c: None, {}),
            # int, str, and None defaults are all preserved as-is.
            (lambda a, b=10, c="test", d=None: None, {"b": 10, "c": "test", "d": None}),
        ],
    )
    def test_default_extraction(self, func_def, expected):
        assert _get_param_defaults(func_def) == expected
class TestExtractUserId:
    """Test _extract_user_id function."""

    def test_user_extraction(self):
        """Test extracting user_id from kwargs and args."""
        kwarg_user = SimpleNamespace(id=UUID("123e4567-e89b-12d3-a456-426614174000"))
        positional_user = SimpleNamespace(id="user-123")

        # Found via the keyword dictionary; UUID ids come back as strings.
        extracted = _extract_user_id((), {"user": kwarg_user}, ["user", "other"])
        assert extracted == "123e4567-e89b-12d3-a456-426614174000"

        # Found positionally through the parameter-name list.
        assert _extract_user_id((positional_user, "other"), {}, ["user", "other"]) == "user-123"

        # No "user" parameter declared at all.
        assert _extract_user_id(("arg1",), {}, ["param1"]) is None

        # Explicit None user.
        assert _extract_user_id((None,), {}, ["user"]) is None

        # User object without an id attribute.
        assert _extract_user_id((SimpleNamespace(name="test"),), {}, ["user"]) is None
class TestExtractParameters:
    """Test _extract_parameters function."""

    def test_parameter_extraction(self):
        """Test parameter extraction with various scenarios."""

        def func1(param1, param2, user=None):
            pass

        def func2(param1, param2=42, param3="default", user=None):
            pass

        def func3():
            pass

        def func4(param1, user):
            pass

        # Keyword arguments only; "user" is never part of the result.
        extracted = _extract_parameters(
            (), {"param1": "v1", "param2": 42}, _get_param_names(func1), func1
        )
        assert extracted == {"param1": "v1", "param2": 42}
        assert "user" not in extracted

        # Positional arguments only.
        extracted = _extract_parameters(("v1", 42), {}, _get_param_names(func1), func1)
        assert extracted == {"param1": "v1", "param2": 42}

        # A mix of positional and keyword arguments.
        extracted = _extract_parameters(("v1",), {"param3": "v3"}, _get_param_names(func2), func2)
        assert extracted["param1"] == "v1"
        assert extracted["param3"] == "v3"

        # Declared defaults are filled in for omitted parameters.
        extracted = _extract_parameters(("v1",), {}, _get_param_names(func2), func2)
        assert extracted["param1"] == "v1"
        assert extracted["param2"] == 42
        assert extracted["param3"] == "default"

        # A zero-argument function yields an empty mapping.
        assert _extract_parameters((), {}, _get_param_names(func3), func3) == {}

        # A user passed positionally is excluded as well.
        current_user = SimpleNamespace(id="user-123")
        extracted = _extract_parameters(("v1", current_user), {}, _get_param_names(func4), func4)
        assert extracted == {"param1": "v1"}
        assert "user" not in extracted

        # When signature inspection fails, args fall back to positional names.
        class BadFunc:
            pass

        extracted = _extract_parameters(("arg1", "arg2"), {}, [], BadFunc())
        assert "arg_0" in extracted or "arg_1" in extracted
class TestDecoratorValidation:
    """Test decorator validation and behavior."""

    def test_decorator_validation(self):
        """Test decorator validation and metadata preservation."""
        # Applying the decorator to a sync function raises at decoration time,
        # not at call time.
        with pytest.raises(UsageLoggerError, match="requires an async function"):

            @log_usage()
            def sync_func():
                pass

        # An async function is accepted and the wrapper stays callable.
        @log_usage()
        async def async_func():
            pass

        assert callable(async_func)

        # The wrapper preserves the wrapped function's __name__ and __doc__.
        @log_usage(function_name="test_func", log_type="test")
        async def test_func(param1: str, param2: int = 42):
            """Test docstring."""
            return param1

        assert test_func.__name__ == "test_func"
        assert "Test docstring" in test_func.__doc__