From d438b8f311832bd88a7027bdc6a70254cab6adce Mon Sep 17 00:00:00 2001 From: Mobile-Crest Date: Wed, 3 Dec 2025 12:27:37 +0100 Subject: [PATCH 1/2] add unit test for api utils --- test/unit_test/api/__init__.py | 15 + test/unit_test/api/test_common_utils.py | 164 ++++ test/unit_test/api/test_configs.py | 317 ++++++++ test/unit_test/api/test_json_encode.py | 452 +++++++++++ test/unit_test/api/test_validation_utils.py | 788 ++++++++++++++++++++ test/unit_test/conftest.py | 23 + 6 files changed, 1759 insertions(+) create mode 100644 test/unit_test/api/__init__.py create mode 100644 test/unit_test/api/test_common_utils.py create mode 100644 test/unit_test/api/test_configs.py create mode 100644 test/unit_test/api/test_json_encode.py create mode 100644 test/unit_test/api/test_validation_utils.py create mode 100644 test/unit_test/conftest.py diff --git a/test/unit_test/api/__init__.py b/test/unit_test/api/__init__.py new file mode 100644 index 000000000..177b91dd0 --- /dev/null +++ b/test/unit_test/api/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/test/unit_test/api/test_common_utils.py b/test/unit_test/api/test_common_utils.py new file mode 100644 index 000000000..05b5fc41c --- /dev/null +++ b/test/unit_test/api/test_common_utils.py @@ -0,0 +1,164 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for api.utils.common module. +""" + +import pytest +from api.utils.common import string_to_bytes, bytes_to_string + + +class TestStringToBytes: + """Test cases for string_to_bytes function""" + + def test_string_input_returns_bytes(self): + """Test that string input is converted to bytes""" + input_string = "hello world" + result = string_to_bytes(input_string) + + assert isinstance(result, bytes) + assert result == b"hello world" + + def test_bytes_input_returns_same_bytes(self): + """Test that bytes input is returned unchanged""" + input_bytes = b"hello world" + result = string_to_bytes(input_bytes) + + assert isinstance(result, bytes) + assert result == input_bytes + assert result is input_bytes # Should be the same object + + def test_empty_string(self): + """Test conversion of empty string""" + result = string_to_bytes("") + + assert isinstance(result, bytes) + assert result == b"" + assert len(result) == 0 + + def test_unicode_characters(self): + """Test conversion of Unicode characters""" + input_string = "Hello δΈ–η•Œ 🌍" + result = string_to_bytes(input_string) + + assert isinstance(result, bytes) + # Verify it can be decoded back + assert result.decode("utf-8") == input_string + + def test_special_characters(self): + """Test conversion of special characters""" + input_string = "Hello, world! @#$%^&*()" + result = string_to_bytes(input_string) + + assert isinstance(result, bytes) + assert result.decode("utf-8") == input_string + + @pytest.mark.parametrize("input_val,expected", [ + ("test", b"test"), + ("", b""), + ("123", b"123"), + ("Hello World", b"Hello World"), + ]) + def test_various_string_inputs(self, input_val, expected): + """Test various string inputs""" + result = string_to_bytes(input_val) + assert result == expected + + +class TestBytesToString: + """Test cases for bytes_to_string function""" + + def test_bytes_input_returns_string(self): + """Test that bytes input is converted to string""" + input_bytes = b"hello world" + result = bytes_to_string(input_bytes) + + assert isinstance(result, str) + assert result == "hello world" + + def test_empty_bytes(self): + """Test conversion of empty bytes""" + result = bytes_to_string(b"") + + assert isinstance(result, str) + assert result == "" + assert len(result) == 0 + + def test_unicode_bytes(self): + """Test conversion of Unicode bytes""" + input_bytes = "Hello δΈ–η•Œ 🌍".encode("utf-8") + result = bytes_to_string(input_bytes) + + assert isinstance(result, str) + assert result == "Hello δΈ–η•Œ 🌍" + + @pytest.mark.parametrize("input_bytes,expected", [ + (b"test", "test"), + (b"", ""), + (b"123", "123"), + (b"Hello World", "Hello World"), + ]) + def test_various_bytes_inputs(self, input_bytes, expected): + """Test various bytes inputs""" + result = bytes_to_string(input_bytes) + assert result == expected + + def test_invalid_utf8_raises_error(self): + """Test that invalid UTF-8 bytes raise an error""" + # Invalid UTF-8 sequence + invalid_bytes = b"\xff\xfe" + + with pytest.raises(UnicodeDecodeError): + bytes_to_string(invalid_bytes) + + +class TestRoundtripConversion: + """Test roundtrip conversions between string and bytes""" + + def test_string_to_bytes_to_string(self): + """Test converting string to bytes and back""" + original = "Hello, World! δΈ–η•Œ" + + as_bytes = string_to_bytes(original) + back_to_string = bytes_to_string(as_bytes) + + assert back_to_string == original + + def test_bytes_to_string_to_bytes(self): + """Test converting bytes to string and back""" + original = b"Hello, World!" + + as_string = bytes_to_string(original) + back_to_bytes = string_to_bytes(as_string) + + assert back_to_bytes == original + + @pytest.mark.parametrize("test_string", [ + "Simple text", + "Unicode: δ½ ε₯½δΈ–η•Œ 🌍", + "Special: !@#$%^&*()", + "Multiline\nWith\tTabs", + "", + ]) + def test_roundtrip_various_strings(self, test_string): + """Test roundtrip conversion for various strings""" + result = bytes_to_string(string_to_bytes(test_string)) + assert result == test_string + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/unit_test/api/test_configs.py b/test/unit_test/api/test_configs.py new file mode 100644 index 000000000..7bcc5b0b7 --- /dev/null +++ b/test/unit_test/api/test_configs.py @@ -0,0 +1,317 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for api.utils.configs module. +""" + +import pytest +import pickle +import base64 +from unittest.mock import patch +from api.utils.configs import ( + serialize_b64, + deserialize_b64, + RestrictedUnpickler, + restricted_loads, + safe_module +) + + +class TestSerializeB64: + """Test cases for serialize_b64 function""" + + def test_serialize_dict(self): + """Test serialization of a dictionary""" + test_dict = {"key": "value", "number": 42} + result = serialize_b64(test_dict) + + assert isinstance(result, bytes) + # Should be valid base64 + decoded = base64.b64decode(result) + assert isinstance(decoded, bytes) + + def test_serialize_list(self): + """Test serialization of a list""" + test_list = [1, 2, 3, "test", {"nested": "dict"}] + result = serialize_b64(test_list) + + assert isinstance(result, bytes) + + def test_serialize_with_to_str_false(self): + """Test serialization with to_str=False returns bytes""" + test_data = {"test": "data"} + result = serialize_b64(test_data, to_str=False) + + assert isinstance(result, bytes) + + def test_serialize_with_to_str_true(self): + """Test serialization with to_str=True returns string""" + test_data = {"test": "data"} + result = serialize_b64(test_data, to_str=True) + + assert isinstance(result, str) + # Should be valid base64 string + base64.b64decode(result) # Should not raise + + def test_serialize_string(self): + """Test serialization of a string""" + test_string = "Hello, World!" + result = serialize_b64(test_string) + + assert isinstance(result, bytes) + + def test_serialize_number(self): + """Test serialization of numbers""" + test_int = 12345 + result = serialize_b64(test_int) + + assert isinstance(result, bytes) + + def test_serialize_complex_nested_structure(self): + """Test serialization of complex nested structures""" + test_data = { + "list": [1, 2, 3], + "dict": {"nested": {"deep": "value"}}, + "tuple": (1, 2, 3), + "string": "test", + "number": 42.5 + } + result = serialize_b64(test_data) + + assert isinstance(result, bytes) + + def test_serialize_none(self): + """Test serialization of None""" + result = serialize_b64(None) + + assert isinstance(result, bytes) + + def test_serialize_empty_dict(self): + """Test serialization of empty dictionary""" + result = serialize_b64({}) + + assert isinstance(result, bytes) + + def test_serialize_empty_list(self): + """Test serialization of empty list""" + result = serialize_b64([]) + + assert isinstance(result, bytes) + + +class TestDeserializeB64: + """Test cases for deserialize_b64 function""" + + def test_deserialize_string_input(self): + """Test deserialization with string input""" + test_data = {"key": "value"} + serialized = serialize_b64(test_data, to_str=True) + + result = deserialize_b64(serialized) + + assert result == test_data + + def test_deserialize_bytes_input(self): + """Test deserialization with bytes input""" + test_data = {"key": "value"} + serialized = serialize_b64(test_data, to_str=False) + + result = deserialize_b64(serialized) + + assert result == test_data + + @patch('api.utils.configs.get_base_config') + def test_deserialize_with_safe_module_disabled(self, mock_config): + """Test deserialization with safe module checking disabled""" + mock_config.return_value = False + + test_data = {"test": "data"} + serialized = serialize_b64(test_data, to_str=True) + + result = deserialize_b64(serialized) + + assert result == test_data + mock_config.assert_called_once_with('use_deserialize_safe_module', False) + + @patch('api.utils.configs.get_base_config') + def test_deserialize_with_safe_module_enabled(self, mock_config): + """Test deserialization with safe module checking enabled""" + mock_config.return_value = True + + # Simple data that doesn't require unsafe modules + test_data = {"test": "data", "number": 42} + serialized = serialize_b64(test_data, to_str=True) + + result = deserialize_b64(serialized) + + assert result == test_data + + def test_roundtrip_serialization(self): + """Test complete roundtrip serialization and deserialization""" + test_data = { + "string": "test", + "number": 123, + "list": [1, 2, 3], + "nested": {"key": "value"} + } + + serialized = serialize_b64(test_data, to_str=True) + deserialized = deserialize_b64(serialized) + + assert deserialized == test_data + + @pytest.mark.parametrize("test_data", [ + {"key": "value"}, + [1, 2, 3, 4, 5], + "simple string", + 42, + 3.14, + None, + {"nested": {"deep": {"structure": "value"}}}, + ]) + def test_roundtrip_various_data_types(self, test_data): + """Test roundtrip for various data types""" + serialized = serialize_b64(test_data) + deserialized = deserialize_b64(serialized) + + assert deserialized == test_data + + +class TestRestrictedUnpickler: + """Test cases for RestrictedUnpickler class""" + + def test_allows_safe_modules(self): + """Test that safe modules are allowed""" + # Create a simple object that would be in a safe module context + test_data = {"test": "data"} + pickled = pickle.dumps(test_data) + + # This should work without raising + result = restricted_loads(pickled) + assert result == test_data + + @patch('api.utils.configs.get_base_config') + def test_blocks_unsafe_modules(self, mock_config): + """Test that unsafe modules are blocked""" + mock_config.return_value = True + + # Try to pickle something from an unsafe module + # We'll simulate this by creating a pickle that references os module + class UnsafeClass: + __module__ = 'os.path' + def __reduce__(self): + return (eval, ("1+1",)) + + try: + unsafe_obj = UnsafeClass() + pickled = pickle.dumps(unsafe_obj) + + with pytest.raises(pickle.UnpicklingError): + restricted_loads(pickled) + except: + # If we can't create the unsafe pickle, skip this test + pytest.skip("Unable to create unsafe pickle for testing") + + def test_safe_module_set_contains_expected_modules(self): + """Test that safe_module set contains expected modules""" + assert 'numpy' in safe_module + assert 'rag_flow' in safe_module + + def test_restricted_loads_with_safe_data(self): + """Test restricted_loads with safe data""" + test_data = [1, 2, 3, "test", {"key": "value"}] + pickled = pickle.dumps(test_data) + + result = restricted_loads(pickled) + + assert result == test_data + + +class TestIntegrationScenarios: + """Integration tests for serialization/deserialization workflows""" + + def test_serialize_deserialize_workflow(self): + """Test complete workflow of serialize and deserialize""" + original_data = { + "user": "test_user", + "settings": { + "theme": "dark", + "notifications": True + }, + "items": [1, 2, 3, 4, 5] + } + + # Serialize to string + serialized_str = serialize_b64(original_data, to_str=True) + assert isinstance(serialized_str, str) + + # Deserialize back + deserialized = deserialize_b64(serialized_str) + assert deserialized == original_data + + def test_serialize_deserialize_with_bytes(self): + """Test workflow using bytes format""" + original_data = {"test": "data", "number": 42} + + # Serialize to bytes + serialized_bytes = serialize_b64(original_data, to_str=False) + assert isinstance(serialized_bytes, bytes) + + # Deserialize back + deserialized = deserialize_b64(serialized_bytes) + assert deserialized == original_data + + @patch('api.utils.configs.get_base_config') + def test_safe_deserialization_workflow(self, mock_config): + """Test safe deserialization workflow""" + mock_config.return_value = True + + test_data = {"safe": "data"} + serialized = serialize_b64(test_data, to_str=True) + + result = deserialize_b64(serialized) + + assert result == test_data + + def test_empty_data_workflow(self): + """Test workflow with empty data""" + empty_dict = {} + + serialized = serialize_b64(empty_dict, to_str=True) + deserialized = deserialize_b64(serialized) + + assert deserialized == empty_dict + + def test_large_data_workflow(self): + """Test workflow with larger data structures""" + large_data = { + f"key_{i}": { + "value": i, + "list": list(range(10)), + "nested": {"deep": f"value_{i}"} + } + for i in range(100) + } + + serialized = serialize_b64(large_data, to_str=True) + deserialized = deserialize_b64(serialized) + + assert deserialized == large_data + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/unit_test/api/test_json_encode.py b/test/unit_test/api/test_json_encode.py new file mode 100644 index 000000000..0114137da --- /dev/null +++ b/test/unit_test/api/test_json_encode.py @@ -0,0 +1,452 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for api.utils.json_encode module. +""" + +import pytest +import json +import datetime +from enum import Enum, IntEnum +from api.utils.json_encode import ( + BaseType, + CustomJSONEncoder, + json_dumps, + json_loads +) + + +class TestBaseTypeToDict: + """Test cases for BaseType.to_dict method""" + + def test_simple_object_to_dict(self): + """Test converting simple object to dictionary""" + class SimpleType(BaseType): + def __init__(self): + self.name = "test" + self.value = 42 + + obj = SimpleType() + result = obj.to_dict() + + assert isinstance(result, dict) + assert result == {"name": "test", "value": 42} + + def test_private_attributes_excluded(self): + """Test that private attributes (starting with _) are stripped""" + class PrivateType(BaseType): + def __init__(self): + self._private = "hidden" + self._another = "also hidden" + self.public = "visible" + + obj = PrivateType() + result = obj.to_dict() + + # Private attributes should have _ stripped in keys + assert "private" in result + assert "another" in result + assert "public" in result + assert result["private"] == "hidden" + assert result["another"] == "also hidden" + assert result["public"] == "visible" + + def test_empty_object(self): + """Test converting empty object to dictionary""" + class EmptyType(BaseType): + pass + + obj = EmptyType() + result = obj.to_dict() + + assert isinstance(result, dict) + assert len(result) == 0 + + def test_nested_object_to_dict(self): + """Test converting object with nested values""" + class NestedType(BaseType): + def __init__(self): + self.data = {"key": "value"} + self.items = [1, 2, 3] + self.text = "test" + + obj = NestedType() + result = obj.to_dict() + + assert result["data"] == {"key": "value"} + assert result["items"] == [1, 2, 3] + assert result["text"] == "test" + + +class TestBaseTypeToDictWithType: + """Test cases for BaseType.to_dict_with_type method""" + + def test_includes_type_information(self): + """Test that type information is included""" + class TypedObject(BaseType): + def __init__(self): + self.value = "test" + + obj = TypedObject() + result = obj.to_dict_with_type() + + assert "type" in result + assert "data" in result + assert result["type"] == "TypedObject" + + def test_includes_module_information(self): + """Test that module information is included""" + class ModuleObject(BaseType): + def __init__(self): + self.value = "test" + + obj = ModuleObject() + result = obj.to_dict_with_type() + + assert "module" in result + assert result["module"] is not None + + def test_nested_objects_with_types(self): + """Test nested BaseType objects include type info""" + class InnerType(BaseType): + def __init__(self): + self.inner_value = "inner" + + class OuterType(BaseType): + def __init__(self): + self.nested = InnerType() + self.value = "outer" + + obj = OuterType() + result = obj.to_dict_with_type() + + assert result["type"] == "OuterType" + assert "data" in result + assert "nested" in result["data"] + assert result["data"]["nested"]["type"] == "InnerType" + + def test_list_handling(self): + """Test handling of lists in to_dict_with_type""" + class ListType(BaseType): + def __init__(self): + self.items = [1, 2, 3] + + obj = ListType() + result = obj.to_dict_with_type() + + assert "items" in result["data"] + items_data = result["data"]["items"] + assert items_data["type"] == "list" + assert isinstance(items_data["data"], list) + + def test_dict_handling(self): + """Test handling of dictionaries in to_dict_with_type""" + class DictType(BaseType): + def __init__(self): + self.config = {"key": "value"} + + obj = DictType() + result = obj.to_dict_with_type() + + assert "config" in result["data"] + config_data = result["data"]["config"] + assert config_data["type"] == "dict" + + +class TestCustomJSONEncoder: + """Test cases for CustomJSONEncoder class""" + + def test_encode_datetime(self): + """Test encoding of datetime objects""" + dt = datetime.datetime(2025, 12, 3, 14, 30, 45) + result = json.dumps(dt, cls=CustomJSONEncoder) + + assert result == '"2025-12-03 14:30:45"' + + def test_encode_date(self): + """Test encoding of date objects""" + d = datetime.date(2025, 12, 3) + result = json.dumps(d, cls=CustomJSONEncoder) + + assert result == '"2025-12-03"' + + def test_encode_timedelta(self): + """Test encoding of timedelta objects""" + td = datetime.timedelta(days=1, hours=2, minutes=30) + result = json.dumps(td, cls=CustomJSONEncoder) + + assert isinstance(json.loads(result), str) + assert "1 day" in json.loads(result) + + def test_encode_enum(self): + """Test encoding of Enum objects""" + class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + result = json.dumps(Color.RED, cls=CustomJSONEncoder) + + assert result == '"red"' + + def test_encode_int_enum(self): + """Test encoding of IntEnum objects""" + class Priority(IntEnum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + + result = json.dumps(Priority.HIGH, cls=CustomJSONEncoder) + + assert result == '3' + + def test_encode_set(self): + """Test encoding of set objects""" + test_set = {1, 2, 3, 4, 5} + result = json.dumps(test_set, cls=CustomJSONEncoder) + + # Set should be converted to list + decoded = json.loads(result) + assert isinstance(decoded, list) + assert set(decoded) == test_set + + def test_encode_basetype_object(self): + """Test encoding of BaseType objects""" + class TestType(BaseType): + def __init__(self): + self.name = "test" + self.value = 42 + + obj = TestType() + result = json.dumps(obj, cls=CustomJSONEncoder) + + decoded = json.loads(result) + assert decoded["name"] == "test" + assert decoded["value"] == 42 + + def test_encode_basetype_with_type_flag(self): + """Test encoding BaseType with with_type flag""" + class TestType(BaseType): + def __init__(self): + self.value = "test" + + obj = TestType() + encoder = CustomJSONEncoder(with_type=True) + result = json.dumps(obj, cls=CustomJSONEncoder, with_type=True) + + decoded = json.loads(result) + assert "type" in decoded + assert "data" in decoded + + def test_encode_type_object(self): + """Test encoding of type objects""" + result = json.dumps(str, cls=CustomJSONEncoder) + + assert result == '"str"' + + def test_encode_nested_structures(self): + """Test encoding of nested structures with various types""" + class TestType(BaseType): + def __init__(self): + self.name = "test" + + data = { + "datetime": datetime.datetime(2025, 12, 3, 12, 0, 0), + "date": datetime.date(2025, 12, 3), + "set": {1, 2, 3}, + "object": TestType(), + "list": [1, 2, 3] + } + + result = json.dumps(data, cls=CustomJSONEncoder) + decoded = json.loads(result) + + assert decoded["datetime"] == "2025-12-03 12:00:00" + assert decoded["date"] == "2025-12-03" + assert isinstance(decoded["set"], list) + assert decoded["object"]["name"] == "test" + + +class TestJsonDumps: + """Test cases for json_dumps function""" + + def test_json_dumps_basic(self): + """Test basic json_dumps functionality""" + data = {"key": "value", "number": 42} + result = json_dumps(data) + + assert isinstance(result, str) + assert json.loads(result) == data + + def test_json_dumps_with_byte_false(self): + """Test json_dumps with byte=False returns string""" + data = {"test": "data"} + result = json_dumps(data, byte=False) + + assert isinstance(result, str) + + def test_json_dumps_with_byte_true(self): + """Test json_dumps with byte=True returns bytes""" + data = {"test": "data"} + result = json_dumps(data, byte=True) + + assert isinstance(result, bytes) + + def test_json_dumps_with_indent(self): + """Test json_dumps with indentation""" + data = {"key": "value"} + result = json_dumps(data, indent=2) + + assert isinstance(result, str) + assert "\n" in result # Indented JSON has newlines + + def test_json_dumps_with_type_false(self): + """Test json_dumps with with_type=False""" + class TestType(BaseType): + def __init__(self): + self.value = "test" + + obj = TestType() + result = json_dumps(obj, with_type=False) + + decoded = json.loads(result) + assert "type" not in decoded + assert decoded["value"] == "test" + + def test_json_dumps_with_type_true(self): + """Test json_dumps with with_type=True""" + class TestType(BaseType): + def __init__(self): + self.value = "test" + + obj = TestType() + result = json_dumps(obj, with_type=True) + + decoded = json.loads(result) + assert "type" in decoded + assert "data" in decoded + + def test_json_dumps_datetime(self): + """Test json_dumps with datetime objects""" + data = { + "timestamp": datetime.datetime(2025, 12, 3, 15, 30, 0) + } + result = json_dumps(data) + + decoded = json.loads(result) + assert decoded["timestamp"] == "2025-12-03 15:30:00" + + +class TestJsonLoads: + """Test cases for json_loads function""" + + def test_json_loads_string_input(self): + """Test json_loads with string input""" + json_string = '{"key": "value", "number": 42}' + result = json_loads(json_string) + + assert isinstance(result, dict) + assert result["key"] == "value" + assert result["number"] == 42 + + def test_json_loads_bytes_input(self): + """Test json_loads with bytes input""" + json_bytes = b'{"key": "value"}' + result = json_loads(json_bytes) + + assert isinstance(result, dict) + assert result["key"] == "value" + + def test_json_loads_with_object_hook(self): + """Test json_loads with object_hook parameter""" + def custom_hook(obj): + if "special" in obj: + obj["processed"] = True + return obj + + json_string = '{"special": "value"}' + result = json_loads(json_string, object_hook=custom_hook) + + assert result["processed"] is True + + def test_json_loads_empty_object(self): + """Test json_loads with empty object""" + result = json_loads('{}') + + assert isinstance(result, dict) + assert len(result) == 0 + + def test_json_loads_array(self): + """Test json_loads with array""" + result = json_loads('[1, 2, 3, 4, 5]') + + assert isinstance(result, list) + assert result == [1, 2, 3, 4, 5] + + +class TestRoundtripConversion: + """Test roundtrip conversions between dumps and loads""" + + def test_roundtrip_dict(self): + """Test roundtrip conversion of dictionary""" + original = {"key": "value", "number": 42, "list": [1, 2, 3]} + + dumped = json_dumps(original) + loaded = json_loads(dumped) + + assert loaded == original + + def test_roundtrip_with_bytes(self): + """Test roundtrip with byte conversion""" + original = {"test": "data"} + + dumped = json_dumps(original, byte=True) + loaded = json_loads(dumped) + + assert loaded == original + + def test_roundtrip_basetype(self): + """Test roundtrip with BaseType object""" + class TestType(BaseType): + def __init__(self): + self.name = "test" + self.value = 42 + + obj = TestType() + + dumped = json_dumps(obj) + loaded = json_loads(dumped) + + assert loaded["name"] == "test" + assert loaded["value"] == 42 + + @pytest.mark.parametrize("test_data", [ + {"simple": "dict"}, + [1, 2, 3, 4, 5], + {"nested": {"deep": {"structure": "value"}}}, + {"mixed": [1, "two", 3.0, {"four": 4}]}, + ]) + def test_roundtrip_various_structures(self, test_data): + """Test roundtrip for various data structures""" + dumped = json_dumps(test_data) + loaded = json_loads(dumped) + + assert loaded == test_data + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/unit_test/api/test_validation_utils.py b/test/unit_test/api/test_validation_utils.py new file mode 100644 index 000000000..32aba294c --- /dev/null +++ b/test/unit_test/api/test_validation_utils.py @@ -0,0 +1,788 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for api.utils.validation_utils module. +""" + +import pytest +from unittest.mock import Mock, AsyncMock +from uuid import UUID, uuid1 +from pydantic import BaseModel, Field, ValidationError +from werkzeug.exceptions import BadRequest, UnsupportedMediaType + +from api.utils.validation_utils import ( + validate_and_parse_json_request, + validate_and_parse_request_args, + format_validation_error_message, + normalize_str, + validate_uuid1_hex, + CreateDatasetReq, + UpdateDatasetReq, + DeleteDatasetReq, + ListDatasetReq, + ParserConfig, + RaptorConfig, + GraphragConfig, +) + + +class TestNormalizeStr: + """Test cases for normalize_str function""" + + def test_normalize_string_with_whitespace(self): + """Test normalization of string with leading/trailing whitespace""" + result = normalize_str(" Admin ") + assert result == "admin" + + def test_normalize_string_uppercase(self): + """Test normalization converts to lowercase""" + result = normalize_str("UPPERCASE") + assert result == "uppercase" + + def test_normalize_mixed_case(self): + """Test normalization of mixed case string""" + result = normalize_str(" MiXeD CaSe ") + assert result == "mixed case" + + def test_normalize_empty_string(self): + """Test normalization of empty string""" + result = normalize_str("") + assert result == "" + + def test_normalize_whitespace_only(self): + """Test normalization of whitespace-only string""" + result = normalize_str(" ") + assert result == "" + + def test_preserve_non_string_integer(self): + """Test that integers are preserved""" + result = normalize_str(42) + assert result == 42 + + def test_preserve_non_string_none(self): + """Test that None is preserved""" + result = normalize_str(None) + assert result is None + + def test_preserve_non_string_list(self): + """Test that lists are preserved""" + input_list = ["User", "Admin"] + result = normalize_str(input_list) + assert result == input_list + + def test_preserve_non_string_dict(self): + """Test that dicts are preserved""" + input_dict = {"role": "Admin"} + result = normalize_str(input_dict) + assert result == input_dict + + @pytest.mark.parametrize("input_val,expected", [ + ("ReadOnly", "readonly"), + (" ADMIN ", "admin"), + ("User", "user"), + ("", ""), + (123, 123), + (False, False), + (0, 0), + ]) + def test_various_inputs(self, input_val, expected): + """Test various input types""" + result = normalize_str(input_val) + assert result == expected + + +class TestValidateUuid1Hex: + """Test cases for validate_uuid1_hex function""" + + def test_valid_uuid1_string(self): + """Test validation of valid UUID1 string""" + uuid1_obj = uuid1() + uuid1_str = str(uuid1_obj) + + result = validate_uuid1_hex(uuid1_str) + + assert isinstance(result, str) + assert len(result) == 32 + assert result == uuid1_obj.hex + + def test_valid_uuid1_object(self): + """Test validation of valid UUID1 object""" + uuid1_obj = uuid1() + + result = validate_uuid1_hex(uuid1_obj) + + assert isinstance(result, str) + assert result == uuid1_obj.hex + + def test_uuid1_hex_no_hyphens(self): + """Test that result has no hyphens""" + uuid1_obj = uuid1() + result = validate_uuid1_hex(uuid1_obj) + + assert "-" not in result + + def test_invalid_uuid_string_raises_error(self): + """Test that invalid UUID string raises error""" + from pydantic_core import PydanticCustomError + + with pytest.raises(PydanticCustomError) as exc_info: + validate_uuid1_hex("not-a-uuid") + + assert exc_info.value.type == "invalid_UUID1_format" + + def test_non_uuid1_version_raises_error(self): + """Test that non-UUID1 version raises error""" + from pydantic_core import PydanticCustomError + from uuid import uuid4 + + uuid4_obj = uuid4() + + with pytest.raises(PydanticCustomError) as exc_info: + validate_uuid1_hex(uuid4_obj) + + assert exc_info.value.type == "invalid_UUID1_format" + + def test_integer_input_raises_error(self): + """Test that integer input raises error""" + from pydantic_core import PydanticCustomError + + with pytest.raises(PydanticCustomError): + validate_uuid1_hex(12345) + + def test_none_input_raises_error(self): + """Test that None input raises error""" + from pydantic_core import PydanticCustomError + + with pytest.raises(PydanticCustomError): + validate_uuid1_hex(None) + + +class TestFormatValidationErrorMessage: + """Test cases for format_validation_error_message function""" + + def test_single_validation_error(self): + """Test formatting of single validation error""" + class TestModel(BaseModel): + name: str + + try: + TestModel(name=123) + except ValidationError as e: + result = format_validation_error_message(e) + + assert "Field: " in result + assert "Message:" in result + assert "Value: <123>" in result + + def test_multiple_validation_errors(self): + """Test formatting of multiple validation errors""" + class TestModel(BaseModel): + name: str + age: int + + try: + TestModel(name=123, age="not_an_int") + except ValidationError as e: + result = format_validation_error_message(e) + + assert "Field: " in result + assert "Field: " in result + assert "\n" in result # Multiple errors separated by newlines + + def test_long_value_truncation(self): + """Test that long values are truncated""" + class TestModel(BaseModel): + text: str + + long_value = "x" * 200 + + try: + TestModel(text=123) # Wrong type to trigger error + except ValidationError as e: + # Manually create error with long value + pass + + # Create a model with max_length constraint + class TestModel2(BaseModel): + text: str = Field(max_length=10) + + try: + TestModel2(text="x" * 200) + except ValidationError as e: + result = format_validation_error_message(e) + + # Check that value is truncated + assert "..." in result or len(result) < 500 + + def test_nested_field_error(self): + """Test formatting of nested field errors""" + class NestedModel(BaseModel): + value: int + + class ParentModel(BaseModel): + nested: NestedModel + + try: + ParentModel(nested={"value": "not_an_int"}) + except ValidationError as e: + result = format_validation_error_message(e) + + assert "nested.value" in result + + +class TestValidateAndParseRequestArgs: + """Test cases for validate_and_parse_request_args function""" + + def test_valid_request_args(self): + """Test validation of valid request arguments""" + class TestValidator(BaseModel): + param1: str + param2: int = 10 + + mock_request = Mock() + mock_request.args.to_dict.return_value = {"param1": "value", "param2": "20"} + + result, error = validate_and_parse_request_args(mock_request, TestValidator) + + assert error is None + assert result is not None + assert result["param1"] == "value" + assert result["param2"] == 20 + + def test_missing_required_field(self): + """Test validation with missing required field""" + class TestValidator(BaseModel): + required_field: str + + mock_request = Mock() + mock_request.args.to_dict.return_value = {} + + result, error = validate_and_parse_request_args(mock_request, TestValidator) + + assert result is None + assert error is not None + assert "required_field" in error + + def test_with_extras_parameter(self): + """Test validation with extras parameter""" + class TestValidator(BaseModel): + param1: str + internal_id: int + + mock_request = Mock() + mock_request.args.to_dict.return_value = {"param1": "value"} + + result, error = validate_and_parse_request_args( + mock_request, + TestValidator, + extras={"internal_id": 123} + ) + + assert error is None + assert result is not None + assert result["param1"] == "value" + assert "internal_id" not in result # Extras should be removed + + def test_type_conversion(self): + """Test that Pydantic performs type conversion""" + class TestValidator(BaseModel): + number: int + + mock_request = Mock() + mock_request.args.to_dict.return_value = {"number": "42"} + + result, error = validate_and_parse_request_args(mock_request, TestValidator) + + assert error is None + assert result["number"] == 42 + assert isinstance(result["number"], int) + + +class TestValidateAndParseJsonRequest: + """Test cases for validate_and_parse_json_request function""" + + @pytest.mark.anyio + async def test_valid_json_request(self): + """Test validation of valid JSON request""" + class TestValidator(BaseModel): + name: str + value: int + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(return_value={"name": "test", "value": 42}) + + result, error = await validate_and_parse_json_request(mock_request, TestValidator) + + assert error is None + assert result is not None + assert result["name"] == "test" + assert result["value"] == 42 + + @pytest.mark.anyio + async def test_unsupported_content_type(self): + """Test handling of unsupported content type""" + class TestValidator(BaseModel): + name: str + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(side_effect=UnsupportedMediaType()) + mock_request.content_type = "text/xml" + + result, error = await validate_and_parse_json_request(mock_request, TestValidator) + + assert result is None + assert error is not None + assert "Unsupported content type" in error + assert "text/xml" in error + + @pytest.mark.anyio + async def test_malformed_json(self): + """Test handling of malformed JSON""" + class TestValidator(BaseModel): + name: str + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(side_effect=BadRequest()) + + result, error = await validate_and_parse_json_request(mock_request, TestValidator) + + assert result is None + assert error is not None + assert "Malformed JSON syntax" in error + + @pytest.mark.anyio + async def test_invalid_payload_type(self): + """Test handling of non-dict payload""" + class TestValidator(BaseModel): + name: str + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(return_value=["not", "a", "dict"]) + + result, error = await validate_and_parse_json_request(mock_request, TestValidator) + + assert result is None + assert error is not None + assert "Invalid request payload" in error + assert "list" in error + + @pytest.mark.anyio + async def test_validation_error(self): + """Test handling of Pydantic validation errors""" + class TestValidator(BaseModel): + name: str + age: int + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(return_value={"name": 123, "age": "not_int"}) + + result, error = await validate_and_parse_json_request(mock_request, TestValidator) + + assert result is None + assert error is not None + assert "Field:" in error + + @pytest.mark.anyio + async def test_with_extras_parameter(self): + """Test validation with extras parameter""" + class TestValidator(BaseModel): + name: str + user_id: str + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(return_value={"name": "test"}) + + result, error = await validate_and_parse_json_request( + mock_request, + TestValidator, + extras={"user_id": "user_123"} + ) + + assert error is None + assert result is not None + assert result["name"] == "test" + assert "user_id" not in result # Extras should be removed + + @pytest.mark.anyio + async def test_exclude_unset_parameter(self): + """Test exclude_unset parameter""" + class TestValidator(BaseModel): + name: str + optional: str = "default" + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(return_value={"name": "test"}) + + result, error = await validate_and_parse_json_request( + mock_request, + TestValidator, + exclude_unset=True + ) + + assert error is None + assert result is not None + assert "name" in result + assert "optional" not in result # Not set, should be excluded + + +class TestCreateDatasetReq: + """Test cases for CreateDatasetReq validation""" + + def test_valid_dataset_creation(self): + """Test valid dataset creation request""" + data = { + "name": "Test Dataset", + "embedding_model": "text-embedding-3-large@openai" + } + + dataset = CreateDatasetReq(**data) + + assert dataset.name == "Test Dataset" + assert dataset.embedding_model == "text-embedding-3-large@openai" + + def test_name_whitespace_stripping(self): + """Test that name whitespace is stripped""" + data = { + "name": " Test Dataset " + } + + dataset = CreateDatasetReq(**data) + + assert dataset.name == "Test Dataset" + + def test_empty_name_raises_error(self): + """Test that empty name raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq(name="") + + def test_invalid_embedding_model_format(self): + """Test that invalid embedding model format raises error""" + with pytest.raises(ValidationError) as exc_info: + CreateDatasetReq( + name="Test", + embedding_model="invalid_model" + ) + + assert "format_invalid" in str(exc_info.value) + + def test_embedding_model_without_provider(self): + """Test embedding model without provider raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq( + name="Test", + embedding_model="model_name@" + ) + + def test_valid_avatar_base64(self): + """Test valid base64 avatar""" + data = { + "name": "Test", + "avatar": "" + } + + dataset = CreateDatasetReq(**data) + + assert dataset.avatar.startswith("" + ) + + def test_avatar_missing_data_prefix(self): + """Test avatar missing data: prefix raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq( + name="Test", + avatar="image/png;base64,abc123" + ) + + def test_default_chunk_method(self): + """Test default chunk_method is set to 'naive'""" + dataset = CreateDatasetReq(name="Test") + + assert dataset.chunk_method == "naive" + + def test_valid_chunk_methods(self): + """Test various valid chunk methods""" + valid_methods = ["naive", "book", "email", "laws", "manual", "one", + "paper", "picture", "presentation", "qa", "table", "tag"] + + for method in valid_methods: + dataset = CreateDatasetReq(name="Test", chunk_method=method) + assert dataset.chunk_method == method + + def test_invalid_chunk_method(self): + """Test invalid chunk method raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq(name="Test", chunk_method="invalid_method") + + def test_pipeline_id_validation(self): + """Test pipeline_id validation""" + # Valid 32-char hex + dataset = CreateDatasetReq( + name="Test", + parse_type=1, + pipeline_id="a" * 32 + ) + assert dataset.pipeline_id == "a" * 32 + + def test_pipeline_id_wrong_length(self): + """Test pipeline_id with wrong length raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq( + name="Test", + parse_type=1, + pipeline_id="abc" # Too short + ) + + def test_pipeline_id_non_hex(self): + """Test pipeline_id with non-hex characters raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq( + name="Test", + parse_type=1, + pipeline_id="g" * 32 # 'g' is not hex + ) + + +class TestUpdateDatasetReq: + """Test cases for UpdateDatasetReq validation""" + + def test_valid_update_request(self): + """Test valid dataset update request""" + uuid1_obj = uuid1() + + data = { + "dataset_id": str(uuid1_obj), + "name": "Updated Dataset" + } + + dataset = UpdateDatasetReq(**data) + + assert dataset.dataset_id == uuid1_obj.hex + assert dataset.name == "Updated Dataset" + + def test_dataset_id_uuid1_validation(self): + """Test that dataset_id must be UUID1""" + from uuid import uuid4 + + with pytest.raises(ValidationError): + UpdateDatasetReq( + dataset_id=str(uuid4()), + name="Test" + ) + + def test_pagerank_validation(self): + """Test pagerank field validation""" + uuid1_obj = uuid1() + + dataset = UpdateDatasetReq( + dataset_id=str(uuid1_obj), + name="Test", + pagerank=50 + ) + + assert dataset.pagerank == 50 + + def test_pagerank_out_of_range(self): + """Test pagerank out of range raises error""" + uuid1_obj = uuid1() + + with pytest.raises(ValidationError): + UpdateDatasetReq( + dataset_id=str(uuid1_obj), + name="Test", + pagerank=101 # Max is 100 + ) + + +class TestDeleteDatasetReq: + """Test cases for DeleteDatasetReq validation""" + + def test_valid_delete_request(self): + """Test valid delete request""" + uuid1_obj1 = uuid1() + uuid1_obj2 = uuid1() + + req = DeleteDatasetReq(ids=[str(uuid1_obj1), str(uuid1_obj2)]) + + assert len(req.ids) == 2 + assert uuid1_obj1.hex in req.ids + assert uuid1_obj2.hex in req.ids + + def test_duplicate_ids_raises_error(self): + """Test that duplicate IDs raise error""" + uuid1_obj = uuid1() + + with pytest.raises(ValidationError) as exc_info: + DeleteDatasetReq(ids=[str(uuid1_obj), str(uuid1_obj)]) + + assert "duplicate" in str(exc_info.value).lower() + + def test_empty_ids_list(self): + """Test empty IDs list""" + req = DeleteDatasetReq(ids=[]) + + assert req.ids == [] + + def test_none_ids(self): + """Test None IDs""" + req = DeleteDatasetReq(ids=None) + + assert req.ids is None + + +class TestListDatasetReq: + """Test cases for ListDatasetReq validation""" + + def test_default_values(self): + """Test default values for list request""" + req = ListDatasetReq() + + assert req.page == 1 + assert req.page_size == 30 + assert req.orderby == "create_time" + assert req.desc is True + + def test_custom_pagination(self): + """Test custom pagination values""" + req = ListDatasetReq(page=2, page_size=50) + + assert req.page == 2 + assert req.page_size == 50 + + def test_page_minimum_value(self): + """Test page minimum value validation""" + with pytest.raises(ValidationError): + ListDatasetReq(page=0) + + def test_valid_orderby_values(self): + """Test valid orderby values""" + req1 = ListDatasetReq(orderby="create_time") + req2 = ListDatasetReq(orderby="update_time") + + assert req1.orderby == "create_time" + assert req2.orderby == "update_time" + + def test_invalid_orderby_value(self): + """Test invalid orderby value raises error""" + with pytest.raises(ValidationError): + ListDatasetReq(orderby="invalid_field") + + +class TestParserConfig: + """Test cases for ParserConfig validation""" + + def test_default_parser_config(self): + """Test default parser configuration""" + config = ParserConfig() + + assert config.chunk_token_num == 512 + assert config.auto_keywords == 0 + assert config.auto_questions == 0 + + def test_custom_parser_config(self): + """Test custom parser configuration""" + config = ParserConfig( + chunk_token_num=1024, + auto_keywords=5, + auto_questions=3 + ) + + assert config.chunk_token_num == 1024 + assert config.auto_keywords == 5 + assert config.auto_questions == 3 + + def test_chunk_token_num_range(self): + """Test chunk_token_num range validation""" + with pytest.raises(ValidationError): + ParserConfig(chunk_token_num=3000) # Max is 2048 + + def test_raptor_config_integration(self): + """Test raptor config integration""" + config = ParserConfig( + raptor=RaptorConfig(use_raptor=True, max_token=512) + ) + + assert config.raptor.use_raptor is True + assert config.raptor.max_token == 512 + + +class TestRaptorConfig: + """Test cases for RaptorConfig validation""" + + def test_default_raptor_config(self): + """Test default raptor configuration""" + config = RaptorConfig() + + assert config.use_raptor is False + assert config.max_token == 256 + assert config.threshold == 0.1 + + def test_custom_raptor_config(self): + """Test custom raptor configuration""" + config = RaptorConfig( + use_raptor=True, + max_token=512, + threshold=0.2 + ) + + assert config.use_raptor is True + assert config.max_token == 512 + assert config.threshold == 0.2 + + def test_threshold_range(self): + """Test threshold range validation""" + with pytest.raises(ValidationError): + RaptorConfig(threshold=1.5) # Max is 1.0 + + +class TestGraphragConfig: + """Test cases for GraphragConfig validation""" + + def test_default_graphrag_config(self): + """Test default graphrag configuration""" + config = GraphragConfig() + + assert config.use_graphrag is False + assert config.method == "light" + assert "organization" in config.entity_types + + def test_custom_graphrag_config(self): + """Test custom graphrag configuration""" + config = GraphragConfig( + use_graphrag=True, + method="general", + entity_types=["person", "location"] + ) + + assert config.use_graphrag is True + assert config.method == "general" + assert config.entity_types == ["person", "location"] + + def test_invalid_method(self): + """Test invalid method raises error""" + with pytest.raises(ValidationError): + GraphragConfig(method="invalid") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/unit_test/conftest.py b/test/unit_test/conftest.py new file mode 100644 index 000000000..434588e49 --- /dev/null +++ b/test/unit_test/conftest.py @@ -0,0 +1,23 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +from pathlib import Path + +# Add project root to Python path +project_root = Path(__file__).parent.parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) From 5f17a2b6a902a894d28bbbea60be003a3e81f56d Mon Sep 17 00:00:00 2001 From: Mobile-Crest Date: Wed, 3 Dec 2025 13:12:47 +0100 Subject: [PATCH 2/2] follow same test pattern --- test/unit_test/api/test_common_utils.py | 102 ++++-------- test/unit_test/api/test_configs.py | 208 +++++++----------------- 2 files changed, 90 insertions(+), 220 deletions(-) diff --git a/test/unit_test/api/test_common_utils.py b/test/unit_test/api/test_common_utils.py index 05b5fc41c..5caf4ddb9 100644 --- a/test/unit_test/api/test_common_utils.py +++ b/test/unit_test/api/test_common_utils.py @@ -42,79 +42,38 @@ class TestStringToBytes: assert result == input_bytes assert result is input_bytes # Should be the same object - def test_empty_string(self): - """Test conversion of empty string""" - result = string_to_bytes("") - - assert isinstance(result, bytes) - assert result == b"" - assert len(result) == 0 - - def test_unicode_characters(self): - """Test conversion of Unicode characters""" - input_string = "Hello δΈ–η•Œ 🌍" - result = string_to_bytes(input_string) - - assert isinstance(result, bytes) - # Verify it can be decoded back - assert result.decode("utf-8") == input_string - - def test_special_characters(self): - """Test conversion of special characters""" - input_string = "Hello, world! @#$%^&*()" - result = string_to_bytes(input_string) - - assert isinstance(result, bytes) - assert result.decode("utf-8") == input_string - @pytest.mark.parametrize("input_val,expected", [ ("test", b"test"), ("", b""), ("123", b"123"), ("Hello World", b"Hello World"), + ("Hello δΈ–η•Œ 🌍", "Hello δΈ–η•Œ 🌍".encode("utf-8")), + ("Hello, world! @#$%^&*()", b"Hello, world! @#$%^&*()"), + ("Newline\nTab\tQuote\"", b"Newline\nTab\tQuote\""), ]) def test_various_string_inputs(self, input_val, expected): - """Test various string inputs""" + """Test conversion of various string inputs including unicode and special characters""" result = string_to_bytes(input_val) + assert isinstance(result, bytes) assert result == expected class TestBytesToString: """Test cases for bytes_to_string function""" - def test_bytes_input_returns_string(self): - """Test that bytes input is converted to string""" - input_bytes = b"hello world" - result = bytes_to_string(input_bytes) - - assert isinstance(result, str) - assert result == "hello world" - - def test_empty_bytes(self): - """Test conversion of empty bytes""" - result = bytes_to_string(b"") - - assert isinstance(result, str) - assert result == "" - assert len(result) == 0 - - def test_unicode_bytes(self): - """Test conversion of Unicode bytes""" - input_bytes = "Hello δΈ–η•Œ 🌍".encode("utf-8") - result = bytes_to_string(input_bytes) - - assert isinstance(result, str) - assert result == "Hello δΈ–η•Œ 🌍" - @pytest.mark.parametrize("input_bytes,expected", [ + (b"hello world", "hello world"), (b"test", "test"), (b"", ""), (b"123", "123"), (b"Hello World", "Hello World"), + ("Hello δΈ–η•Œ 🌍".encode("utf-8"), "Hello δΈ–η•Œ 🌍"), + (b"Special: @#$%^&*()", "Special: @#$%^&*()"), ]) def test_various_bytes_inputs(self, input_bytes, expected): - """Test various bytes inputs""" + """Test conversion of various bytes inputs including unicode""" result = bytes_to_string(input_bytes) + assert isinstance(result, str) assert result == expected def test_invalid_utf8_raises_error(self): @@ -129,35 +88,32 @@ class TestBytesToString: class TestRoundtripConversion: """Test roundtrip conversions between string and bytes""" - def test_string_to_bytes_to_string(self): - """Test converting string to bytes and back""" - original = "Hello, World! δΈ–η•Œ" - - as_bytes = string_to_bytes(original) - back_to_string = bytes_to_string(as_bytes) - - assert back_to_string == original - - def test_bytes_to_string_to_bytes(self): - """Test converting bytes to string and back""" - original = b"Hello, World!" - - as_string = bytes_to_string(original) - back_to_bytes = string_to_bytes(as_string) - - assert back_to_bytes == original - @pytest.mark.parametrize("test_string", [ "Simple text", + "Hello, World! δΈ–η•Œ", "Unicode: δ½ ε₯½δΈ–η•Œ 🌍", "Special: !@#$%^&*()", "Multiline\nWith\tTabs", "", ]) - def test_roundtrip_various_strings(self, test_string): - """Test roundtrip conversion for various strings""" - result = bytes_to_string(string_to_bytes(test_string)) - assert result == test_string + def test_string_to_bytes_to_string(self, test_string): + """Test converting string to bytes and back for various inputs""" + as_bytes = string_to_bytes(test_string) + back_to_string = bytes_to_string(as_bytes) + assert back_to_string == test_string + + @pytest.mark.parametrize("test_bytes", [ + b"Simple text", + b"Hello, World!", + "Unicode: δ½ ε₯½δΈ–η•Œ 🌍".encode("utf-8"), + b"Special: !@#$%^&*()", + b"", + ]) + def test_bytes_to_string_to_bytes(self, test_bytes): + """Test converting bytes to string and back for various inputs""" + as_string = bytes_to_string(test_bytes) + back_to_bytes = string_to_bytes(as_string) + assert back_to_bytes == test_bytes if __name__ == "__main__": diff --git a/test/unit_test/api/test_configs.py b/test/unit_test/api/test_configs.py index 7bcc5b0b7..309b55aa0 100644 --- a/test/unit_test/api/test_configs.py +++ b/test/unit_test/api/test_configs.py @@ -34,30 +34,31 @@ from api.utils.configs import ( class TestSerializeB64: """Test cases for serialize_b64 function""" - def test_serialize_dict(self): - """Test serialization of a dictionary""" - test_dict = {"key": "value", "number": 42} - result = serialize_b64(test_dict) + @pytest.mark.parametrize("test_data", [ + {"key": "value", "number": 42}, + [1, 2, 3, "test", {"nested": "dict"}], + "Hello, World!", + 12345, + { + "list": [1, 2, 3], + "dict": {"nested": {"deep": "value"}}, + "tuple": (1, 2, 3), + "string": "test", + "number": 42.5 + }, + None, + {}, + [], + ]) + def test_serialize_returns_bytes(self, test_data): + """Test serialization of various data types returns bytes""" + result = serialize_b64(test_data, to_str=False) assert isinstance(result, bytes) # Should be valid base64 decoded = base64.b64decode(result) assert isinstance(decoded, bytes) - def test_serialize_list(self): - """Test serialization of a list""" - test_list = [1, 2, 3, "test", {"nested": "dict"}] - result = serialize_b64(test_list) - - assert isinstance(result, bytes) - - def test_serialize_with_to_str_false(self): - """Test serialization with to_str=False returns bytes""" - test_data = {"test": "data"} - result = serialize_b64(test_data, to_str=False) - - assert isinstance(result, bytes) - def test_serialize_with_to_str_true(self): """Test serialization with to_str=True returns string""" test_data = {"test": "data"} @@ -67,68 +68,15 @@ class TestSerializeB64: # Should be valid base64 string base64.b64decode(result) # Should not raise - def test_serialize_string(self): - """Test serialization of a string""" - test_string = "Hello, World!" - result = serialize_b64(test_string) - - assert isinstance(result, bytes) - - def test_serialize_number(self): - """Test serialization of numbers""" - test_int = 12345 - result = serialize_b64(test_int) - - assert isinstance(result, bytes) - - def test_serialize_complex_nested_structure(self): - """Test serialization of complex nested structures""" - test_data = { - "list": [1, 2, 3], - "dict": {"nested": {"deep": "value"}}, - "tuple": (1, 2, 3), - "string": "test", - "number": 42.5 - } - result = serialize_b64(test_data) - - assert isinstance(result, bytes) - - def test_serialize_none(self): - """Test serialization of None""" - result = serialize_b64(None) - - assert isinstance(result, bytes) - - def test_serialize_empty_dict(self): - """Test serialization of empty dictionary""" - result = serialize_b64({}) - - assert isinstance(result, bytes) - - def test_serialize_empty_list(self): - """Test serialization of empty list""" - result = serialize_b64([]) - - assert isinstance(result, bytes) - class TestDeserializeB64: """Test cases for deserialize_b64 function""" - def test_deserialize_string_input(self): - """Test deserialization with string input""" + @pytest.mark.parametrize("to_str", [True, False]) + def test_deserialize_string_and_bytes_input(self, to_str): + """Test deserialization with both string and bytes input""" test_data = {"key": "value"} - serialized = serialize_b64(test_data, to_str=True) - - result = deserialize_b64(serialized) - - assert result == test_data - - def test_deserialize_bytes_input(self): - """Test deserialization with bytes input""" - test_data = {"key": "value"} - serialized = serialize_b64(test_data, to_str=False) + serialized = serialize_b64(test_data, to_str=to_str) result = deserialize_b64(serialized) @@ -160,22 +108,14 @@ class TestDeserializeB64: assert result == test_data - def test_roundtrip_serialization(self): - """Test complete roundtrip serialization and deserialization""" - test_data = { + @pytest.mark.parametrize("test_data", [ + {"key": "value"}, + { "string": "test", "number": 123, "list": [1, 2, 3], "nested": {"key": "value"} - } - - serialized = serialize_b64(test_data, to_str=True) - deserialized = deserialize_b64(serialized) - - assert deserialized == test_data - - @pytest.mark.parametrize("test_data", [ - {"key": "value"}, + }, [1, 2, 3, 4, 5], "simple string", 42, @@ -184,7 +124,7 @@ class TestDeserializeB64: {"nested": {"deep": {"structure": "value"}}}, ]) def test_roundtrip_various_data_types(self, test_data): - """Test roundtrip for various data types""" + """Test roundtrip serialization and deserialization for various data types""" serialized = serialize_b64(test_data) deserialized = deserialize_b64(serialized) @@ -194,14 +134,19 @@ class TestDeserializeB64: class TestRestrictedUnpickler: """Test cases for RestrictedUnpickler class""" - def test_allows_safe_modules(self): - """Test that safe modules are allowed""" - # Create a simple object that would be in a safe module context - test_data = {"test": "data"} + @pytest.mark.parametrize("test_data", [ + {"test": "data"}, + [1, 2, 3, "test", {"key": "value"}], + {"nested": {"deep": "structure"}}, + [1, 2, 3], + "simple string", + ]) + def test_restricted_loads_with_safe_data(self, test_data): + """Test restricted_loads with various safe data types""" pickled = pickle.dumps(test_data) - # This should work without raising result = restricted_loads(pickled) + assert result == test_data @patch('api.utils.configs.get_base_config') @@ -231,48 +176,42 @@ class TestRestrictedUnpickler: assert 'numpy' in safe_module assert 'rag_flow' in safe_module - def test_restricted_loads_with_safe_data(self): - """Test restricted_loads with safe data""" - test_data = [1, 2, 3, "test", {"key": "value"}] - pickled = pickle.dumps(test_data) - - result = restricted_loads(pickled) - - assert result == test_data - class TestIntegrationScenarios: """Integration tests for serialization/deserialization workflows""" - def test_serialize_deserialize_workflow(self): - """Test complete workflow of serialize and deserialize""" - original_data = { + @pytest.mark.parametrize("to_str,original_data", [ + (True, { "user": "test_user", "settings": { "theme": "dark", "notifications": True }, "items": [1, 2, 3, 4, 5] - } + }), + (False, {"test": "data", "number": 42}), + (True, {}), + (True, { + f"key_{i}": { + "value": i, + "list": list(range(10)), + "nested": {"deep": f"value_{i}"} + } + for i in range(100) + }), + ]) + def test_serialize_deserialize_workflow(self, to_str, original_data): + """Test complete workflow of serialize and deserialize with various data""" + # Serialize + serialized = serialize_b64(original_data, to_str=to_str) - # Serialize to string - serialized_str = serialize_b64(original_data, to_str=True) - assert isinstance(serialized_str, str) + if to_str: + assert isinstance(serialized, str) + else: + assert isinstance(serialized, bytes) # Deserialize back - deserialized = deserialize_b64(serialized_str) - assert deserialized == original_data - - def test_serialize_deserialize_with_bytes(self): - """Test workflow using bytes format""" - original_data = {"test": "data", "number": 42} - - # Serialize to bytes - serialized_bytes = serialize_b64(original_data, to_str=False) - assert isinstance(serialized_bytes, bytes) - - # Deserialize back - deserialized = deserialize_b64(serialized_bytes) + deserialized = deserialize_b64(serialized) assert deserialized == original_data @patch('api.utils.configs.get_base_config') @@ -287,31 +226,6 @@ class TestIntegrationScenarios: assert result == test_data - def test_empty_data_workflow(self): - """Test workflow with empty data""" - empty_dict = {} - - serialized = serialize_b64(empty_dict, to_str=True) - deserialized = deserialize_b64(serialized) - - assert deserialized == empty_dict - - def test_large_data_workflow(self): - """Test workflow with larger data structures""" - large_data = { - f"key_{i}": { - "value": i, - "list": list(range(10)), - "nested": {"deep": f"value_{i}"} - } - for i in range(100) - } - - serialized = serialize_b64(large_data, to_str=True) - deserialized = deserialize_b64(serialized) - - assert deserialized == large_data - if __name__ == "__main__": pytest.main([__file__, "-v"])