ruff ruff

This commit is contained in:
hajdul88 2026-01-16 13:01:46 +01:00
parent 2e9f646edd
commit 1c96e3b469

View file

@ -20,14 +20,17 @@ from cognee.shared.exceptions import UsageLoggerError
class TestSanitizeValue: class TestSanitizeValue:
"""Test _sanitize_value function.""" """Test _sanitize_value function."""
@pytest.mark.parametrize("value,expected", [ @pytest.mark.parametrize(
(None, None), "value,expected",
("string", "string"), [
(42, 42), (None, None),
(3.14, 3.14), ("string", "string"),
(True, True), (42, 42),
(False, False), (3.14, 3.14),
]) (True, True),
(False, False),
],
)
def test_basic_types(self, value, expected): def test_basic_types(self, value, expected):
assert _sanitize_value(value) == expected assert _sanitize_value(value) == expected
@ -41,9 +44,9 @@ class TestSanitizeValue:
def test_collections(self): def test_collections(self):
"""Test list, tuple, and dict serialization.""" """Test list, tuple, and dict serialization."""
assert _sanitize_value([1, "string", UUID("123e4567-e89b-12d3-a456-426614174000"), None]) == [ assert _sanitize_value(
1, "string", "123e4567-e89b-12d3-a456-426614174000", None [1, "string", UUID("123e4567-e89b-12d3-a456-426614174000"), None]
] ) == [1, "string", "123e4567-e89b-12d3-a456-426614174000", None]
assert _sanitize_value((1, "string", True)) == [1, "string", True] assert _sanitize_value((1, "string", True)) == [1, "string", True]
assert _sanitize_value({"key": UUID("123e4567-e89b-12d3-a456-426614174000")}) == { assert _sanitize_value({"key": UUID("123e4567-e89b-12d3-a456-426614174000")}) == {
"key": "123e4567-e89b-12d3-a456-426614174000" "key": "123e4567-e89b-12d3-a456-426614174000"
@ -70,11 +73,14 @@ class TestSanitizeValue:
class TestSanitizeDictKey: class TestSanitizeDictKey:
"""Test _sanitize_dict_key function.""" """Test _sanitize_dict_key function."""
@pytest.mark.parametrize("key,expected_contains", [ @pytest.mark.parametrize(
("simple_key", "simple_key"), "key,expected_contains",
(UUID("123e4567-e89b-12d3-a456-426614174000"), "123e4567-e89b-12d3-a456-426614174000"), [
((1, 2, 3), ["1", "2"]), ("simple_key", "simple_key"),
]) (UUID("123e4567-e89b-12d3-a456-426614174000"), "123e4567-e89b-12d3-a456-426614174000"),
((1, 2, 3), ["1", "2"]),
],
)
def test_key_types(self, key, expected_contains): def test_key_types(self, key, expected_contains):
result = _sanitize_dict_key(key) result = _sanitize_dict_key(key)
assert isinstance(result, str) assert isinstance(result, str)
@ -96,29 +102,36 @@ class TestSanitizeDictKey:
class TestGetParamNames: class TestGetParamNames:
"""Test _get_param_names function.""" """Test _get_param_names function."""
@pytest.mark.parametrize("func_def,expected", [ @pytest.mark.parametrize(
(lambda a, b, c: None, ["a", "b", "c"]), "func_def,expected",
(lambda a, b=42, c="default": None, ["a", "b", "c"]), [
(lambda a, **kwargs: None, ["a", "kwargs"]), (lambda a, b, c: None, ["a", "b", "c"]),
(lambda *args: None, ["args"]), (lambda a, b=42, c="default": None, ["a", "b", "c"]),
]) (lambda a, **kwargs: None, ["a", "kwargs"]),
(lambda *args: None, ["args"]),
],
)
def test_param_extraction(self, func_def, expected): def test_param_extraction(self, func_def, expected):
assert _get_param_names(func_def) == expected assert _get_param_names(func_def) == expected
def test_async_function(self): def test_async_function(self):
async def func(a, b): async def func(a, b):
pass pass
assert _get_param_names(func) == ["a", "b"] assert _get_param_names(func) == ["a", "b"]
class TestGetParamDefaults: class TestGetParamDefaults:
"""Test _get_param_defaults function.""" """Test _get_param_defaults function."""
@pytest.mark.parametrize("func_def,expected", [ @pytest.mark.parametrize(
(lambda a, b=42, c="default", d=None: None, {"b": 42, "c": "default", "d": None}), "func_def,expected",
(lambda a, b, c: None, {}), [
(lambda a, b=10, c="test", d=None: None, {"b": 10, "c": "test", "d": None}), (lambda a, b=42, c="default", d=None: None, {"b": 42, "c": "default", "d": None}),
]) (lambda a, b, c: None, {}),
(lambda a, b=10, c="test", d=None: None, {"b": 10, "c": "test", "d": None}),
],
)
def test_default_extraction(self, func_def, expected): def test_default_extraction(self, func_def, expected):
assert _get_param_defaults(func_def) == expected assert _get_param_defaults(func_def) == expected
@ -132,7 +145,10 @@ class TestExtractUserId:
user2 = SimpleNamespace(id="user-123") user2 = SimpleNamespace(id="user-123")
# From kwargs # From kwargs
assert _extract_user_id((), {"user": user1}, ["user", "other"]) == "123e4567-e89b-12d3-a456-426614174000" assert (
_extract_user_id((), {"user": user1}, ["user", "other"])
== "123e4567-e89b-12d3-a456-426614174000"
)
# From args # From args
assert _extract_user_id((user2, "other"), {}, ["user", "other"]) == "user-123" assert _extract_user_id((user2, "other"), {}, ["user", "other"]) == "user-123"
# Not present # Not present
@ -148,17 +164,23 @@ class TestExtractParameters:
def test_parameter_extraction(self): def test_parameter_extraction(self):
"""Test parameter extraction with various scenarios.""" """Test parameter extraction with various scenarios."""
def func1(param1, param2, user=None): def func1(param1, param2, user=None):
pass pass
def func2(param1, param2=42, param3="default", user=None): def func2(param1, param2=42, param3="default", user=None):
pass pass
def func3(): def func3():
pass pass
def func4(param1, user): def func4(param1, user):
pass pass
# Kwargs only # Kwargs only
result = _extract_parameters((), {"param1": "v1", "param2": 42}, _get_param_names(func1), func1) result = _extract_parameters(
(), {"param1": "v1", "param2": 42}, _get_param_names(func1), func1
)
assert result == {"param1": "v1", "param2": 42} assert result == {"param1": "v1", "param2": 42}
assert "user" not in result assert "user" not in result
@ -185,6 +207,7 @@ class TestExtractParameters:
# Fallback when inspection fails # Fallback when inspection fails
class BadFunc: class BadFunc:
pass pass
result = _extract_parameters(("arg1", "arg2"), {}, [], BadFunc()) result = _extract_parameters(("arg1", "arg2"), {}, [], BadFunc())
assert "arg_0" in result or "arg_1" in result assert "arg_0" in result or "arg_1" in result
@ -196,6 +219,7 @@ class TestDecoratorValidation:
"""Test decorator validation and metadata preservation.""" """Test decorator validation and metadata preservation."""
# Sync function raises error # Sync function raises error
with pytest.raises(UsageLoggerError, match="requires an async function"): with pytest.raises(UsageLoggerError, match="requires an async function"):
@log_usage() @log_usage()
def sync_func(): def sync_func():
pass pass
@ -204,6 +228,7 @@ class TestDecoratorValidation:
@log_usage() @log_usage()
async def async_func(): async def async_func():
pass pass
assert callable(async_func) assert callable(async_func)
# Metadata preserved # Metadata preserved