chore: cleaning usage logger logic

This commit is contained in:
hajdul88 2026-01-15 15:16:26 +01:00
parent e803f10417
commit bf2357e7bf

View file

@ -2,7 +2,7 @@ import asyncio
import inspect
import os
from datetime import datetime, timezone
from functools import wraps
from functools import singledispatch, wraps
from typing import Any, Callable, Optional
from uuid import UUID
@ -14,31 +14,9 @@ from cognee import __version__ as cognee_version
logger = get_logger("usage_logger")
@singledispatch
def _sanitize_value(value: Any) -> Any:
"""Ensure value is JSON serializable - converts non-serializable values to default messages."""
if value is None:
return None
if isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, (UUID,)):
return str(value)
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, (list, tuple)):
return [_sanitize_value(v) for v in value]
if isinstance(value, dict):
sanitized = {}
for k, v in value.items():
if isinstance(k, str):
key_str = k
else:
sanitized_key = _sanitize_value(k)
if isinstance(sanitized_key, str):
key_str = sanitized_key
else:
key_str = str(sanitized_key) if sanitized_key != "<cannot be serialized>" else f"<key:{type(k).__name__}>"
sanitized[key_str] = _sanitize_value(v)
return sanitized
"""Default handler for JSON serialization - converts to string."""
try:
str_repr = str(value)
if str_repr.startswith("<") and str_repr.endswith(">"):
@ -48,7 +26,64 @@ def _sanitize_value(value: Any) -> Any:
return f"<cannot be serialized: {type(value).__name__}>"
def _extract_user_id(args: tuple, kwargs: dict, func: Callable) -> Optional[str]:
@_sanitize_value.register(type(None))
def _(value: None) -> None:
return None
@_sanitize_value.register(str)
@_sanitize_value.register(int)
@_sanitize_value.register(float)
@_sanitize_value.register(bool)
def _(value: str | int | float | bool) -> str | int | float | bool:
return value
@_sanitize_value.register(UUID)
def _(value: UUID) -> str:
return str(value)
@_sanitize_value.register(datetime)
def _(value: datetime) -> str:
return value.isoformat()
@_sanitize_value.register(list)
@_sanitize_value.register(tuple)
def _(value: list | tuple) -> list:
return [_sanitize_value(v) for v in value]
@_sanitize_value.register(dict)
def _(value: dict) -> dict:
sanitized = {}
for k, v in value.items():
key_str = k if isinstance(k, str) else _sanitize_dict_key(k)
sanitized[key_str] = _sanitize_value(v)
return sanitized
def _sanitize_dict_key(key: Any) -> str:
"""Convert a non-string dict key to a string."""
sanitized_key = _sanitize_value(key)
if isinstance(sanitized_key, str):
# If it's a "cannot be serialized" message, use a key-specific message
if sanitized_key.startswith("<cannot be serialized"):
return f"<key:{type(key).__name__}>"
return sanitized_key
return str(sanitized_key)
def _get_param_names(func: Callable) -> list[str]:
"""Get parameter names from function signature."""
try:
return list(inspect.signature(func).parameters.keys())
except Exception:
return []
def _extract_user_id(args: tuple, kwargs: dict, param_names: list[str]) -> Optional[str]:
"""Extract user_id from function arguments if available."""
try:
if "user" in kwargs and kwargs["user"] is not None:
@ -56,8 +91,6 @@ def _extract_user_id(args: tuple, kwargs: dict, func: Callable) -> Optional[str]
if hasattr(user, "id"):
return str(user.id)
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
for i, param_name in enumerate(param_names):
if i < len(args) and param_name == "user":
user = args[i]
@ -68,7 +101,7 @@ def _extract_user_id(args: tuple, kwargs: dict, func: Callable) -> Optional[str]
return None
def _extract_parameters(args: tuple, kwargs: dict, func: Callable) -> dict:
def _extract_parameters(args: tuple, kwargs: dict, param_names: list[str]) -> dict:
"""Extract function parameters - captures all parameters, sanitizes for JSON."""
params = {}
@ -76,16 +109,14 @@ def _extract_parameters(args: tuple, kwargs: dict, func: Callable) -> dict:
if key != "user":
params[key] = _sanitize_value(value)
try:
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if param_names:
for i, param_name in enumerate(param_names):
if i < len(args) and param_name != "user" and param_name not in kwargs:
params[param_name] = _sanitize_value(args[i])
except Exception:
else:
# Fallback: capture all args by position if signature inspection fails
for i, arg_value in enumerate(args):
if i not in params.values():
params[f"arg_{i}"] = _sanitize_value(arg_value)
params[f"arg_{i}"] = _sanitize_value(arg_value)
return params
@ -184,8 +215,10 @@ def log_usage(function_name: Optional[str] = None, log_type: str = "function"):
# Capture start time
start_time = datetime.now(timezone.utc)
user_id = _extract_user_id(args, kwargs, func)
parameters = _extract_parameters(args, kwargs, func)
# Get param names once to avoid duplicate signature inspection
param_names = _get_param_names(func)
user_id = _extract_user_id(args, kwargs, param_names)
parameters = _extract_parameters(args, kwargs, param_names)
result = None
success = True