diff --git a/test/unit_test/api/__init__.py b/test/unit_test/api/__init__.py new file mode 100644 index 000000000..177b91dd0 --- /dev/null +++ b/test/unit_test/api/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/test/unit_test/api/test_common_utils.py b/test/unit_test/api/test_common_utils.py new file mode 100644 index 000000000..5caf4ddb9 --- /dev/null +++ b/test/unit_test/api/test_common_utils.py @@ -0,0 +1,120 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for api.utils.common module. 
+""" + +import pytest +from api.utils.common import string_to_bytes, bytes_to_string + + +class TestStringToBytes: + """Test cases for string_to_bytes function""" + + def test_string_input_returns_bytes(self): + """Test that string input is converted to bytes""" + input_string = "hello world" + result = string_to_bytes(input_string) + + assert isinstance(result, bytes) + assert result == b"hello world" + + def test_bytes_input_returns_same_bytes(self): + """Test that bytes input is returned unchanged""" + input_bytes = b"hello world" + result = string_to_bytes(input_bytes) + + assert isinstance(result, bytes) + assert result == input_bytes + assert result is input_bytes # Should be the same object + + @pytest.mark.parametrize("input_val,expected", [ + ("test", b"test"), + ("", b""), + ("123", b"123"), + ("Hello World", b"Hello World"), + ("Hello δΈ–η•Œ 🌍", "Hello δΈ–η•Œ 🌍".encode("utf-8")), + ("Hello, world! @#$%^&*()", b"Hello, world! @#$%^&*()"), + ("Newline\nTab\tQuote\"", b"Newline\nTab\tQuote\""), + ]) + def test_various_string_inputs(self, input_val, expected): + """Test conversion of various string inputs including unicode and special characters""" + result = string_to_bytes(input_val) + assert isinstance(result, bytes) + assert result == expected + + +class TestBytesToString: + """Test cases for bytes_to_string function""" + + @pytest.mark.parametrize("input_bytes,expected", [ + (b"hello world", "hello world"), + (b"test", "test"), + (b"", ""), + (b"123", "123"), + (b"Hello World", "Hello World"), + ("Hello δΈ–η•Œ 🌍".encode("utf-8"), "Hello δΈ–η•Œ 🌍"), + (b"Special: @#$%^&*()", "Special: @#$%^&*()"), + ]) + def test_various_bytes_inputs(self, input_bytes, expected): + """Test conversion of various bytes inputs including unicode""" + result = bytes_to_string(input_bytes) + assert isinstance(result, str) + assert result == expected + + def test_invalid_utf8_raises_error(self): + """Test that invalid UTF-8 bytes raise an error""" + # Invalid UTF-8 sequence + 
invalid_bytes = b"\xff\xfe" + + with pytest.raises(UnicodeDecodeError): + bytes_to_string(invalid_bytes) + + +class TestRoundtripConversion: + """Test roundtrip conversions between string and bytes""" + + @pytest.mark.parametrize("test_string", [ + "Simple text", + "Hello, World! δΈ–η•Œ", + "Unicode: δ½ ε₯½δΈ–η•Œ 🌍", + "Special: !@#$%^&*()", + "Multiline\nWith\tTabs", + "", + ]) + def test_string_to_bytes_to_string(self, test_string): + """Test converting string to bytes and back for various inputs""" + as_bytes = string_to_bytes(test_string) + back_to_string = bytes_to_string(as_bytes) + assert back_to_string == test_string + + @pytest.mark.parametrize("test_bytes", [ + b"Simple text", + b"Hello, World!", + "Unicode: δ½ ε₯½δΈ–η•Œ 🌍".encode("utf-8"), + b"Special: !@#$%^&*()", + b"", + ]) + def test_bytes_to_string_to_bytes(self, test_bytes): + """Test converting bytes to string and back for various inputs""" + as_string = bytes_to_string(test_bytes) + back_to_bytes = string_to_bytes(as_string) + assert back_to_bytes == test_bytes + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/unit_test/api/test_configs.py b/test/unit_test/api/test_configs.py new file mode 100644 index 000000000..309b55aa0 --- /dev/null +++ b/test/unit_test/api/test_configs.py @@ -0,0 +1,231 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for api.utils.configs module. 
+""" + +import pytest +import pickle +import base64 +from unittest.mock import patch +from api.utils.configs import ( + serialize_b64, + deserialize_b64, + RestrictedUnpickler, + restricted_loads, + safe_module +) + + +class TestSerializeB64: + """Test cases for serialize_b64 function""" + + @pytest.mark.parametrize("test_data", [ + {"key": "value", "number": 42}, + [1, 2, 3, "test", {"nested": "dict"}], + "Hello, World!", + 12345, + { + "list": [1, 2, 3], + "dict": {"nested": {"deep": "value"}}, + "tuple": (1, 2, 3), + "string": "test", + "number": 42.5 + }, + None, + {}, + [], + ]) + def test_serialize_returns_bytes(self, test_data): + """Test serialization of various data types returns bytes""" + result = serialize_b64(test_data, to_str=False) + + assert isinstance(result, bytes) + # Should be valid base64 + decoded = base64.b64decode(result) + assert isinstance(decoded, bytes) + + def test_serialize_with_to_str_true(self): + """Test serialization with to_str=True returns string""" + test_data = {"test": "data"} + result = serialize_b64(test_data, to_str=True) + + assert isinstance(result, str) + # Should be valid base64 string + base64.b64decode(result) # Should not raise + + +class TestDeserializeB64: + """Test cases for deserialize_b64 function""" + + @pytest.mark.parametrize("to_str", [True, False]) + def test_deserialize_string_and_bytes_input(self, to_str): + """Test deserialization with both string and bytes input""" + test_data = {"key": "value"} + serialized = serialize_b64(test_data, to_str=to_str) + + result = deserialize_b64(serialized) + + assert result == test_data + + @patch('api.utils.configs.get_base_config') + def test_deserialize_with_safe_module_disabled(self, mock_config): + """Test deserialization with safe module checking disabled""" + mock_config.return_value = False + + test_data = {"test": "data"} + serialized = serialize_b64(test_data, to_str=True) + + result = deserialize_b64(serialized) + + assert result == test_data + 
mock_config.assert_called_once_with('use_deserialize_safe_module', False) + + @patch('api.utils.configs.get_base_config') + def test_deserialize_with_safe_module_enabled(self, mock_config): + """Test deserialization with safe module checking enabled""" + mock_config.return_value = True + + # Simple data that doesn't require unsafe modules + test_data = {"test": "data", "number": 42} + serialized = serialize_b64(test_data, to_str=True) + + result = deserialize_b64(serialized) + + assert result == test_data + + @pytest.mark.parametrize("test_data", [ + {"key": "value"}, + { + "string": "test", + "number": 123, + "list": [1, 2, 3], + "nested": {"key": "value"} + }, + [1, 2, 3, 4, 5], + "simple string", + 42, + 3.14, + None, + {"nested": {"deep": {"structure": "value"}}}, + ]) + def test_roundtrip_various_data_types(self, test_data): + """Test roundtrip serialization and deserialization for various data types""" + serialized = serialize_b64(test_data) + deserialized = deserialize_b64(serialized) + + assert deserialized == test_data + + +class TestRestrictedUnpickler: + """Test cases for RestrictedUnpickler class""" + + @pytest.mark.parametrize("test_data", [ + {"test": "data"}, + [1, 2, 3, "test", {"key": "value"}], + {"nested": {"deep": "structure"}}, + [1, 2, 3], + "simple string", + ]) + def test_restricted_loads_with_safe_data(self, test_data): + """Test restricted_loads with various safe data types""" + pickled = pickle.dumps(test_data) + + result = restricted_loads(pickled) + + assert result == test_data + + @patch('api.utils.configs.get_base_config') + def test_blocks_unsafe_modules(self, mock_config): + """Test that unsafe modules are blocked""" + mock_config.return_value = True + + # Try to pickle something from an unsafe module + # We'll simulate this by creating a pickle that references os module + class UnsafeClass: + __module__ = 'os.path' + def __reduce__(self): + return (eval, ("1+1",)) + + try: + unsafe_obj = UnsafeClass() + pickled = 
pickle.dumps(unsafe_obj) + + with pytest.raises(pickle.UnpicklingError): + restricted_loads(pickled) + except: + # If we can't create the unsafe pickle, skip this test + pytest.skip("Unable to create unsafe pickle for testing") + + def test_safe_module_set_contains_expected_modules(self): + """Test that safe_module set contains expected modules""" + assert 'numpy' in safe_module + assert 'rag_flow' in safe_module + + +class TestIntegrationScenarios: + """Integration tests for serialization/deserialization workflows""" + + @pytest.mark.parametrize("to_str,original_data", [ + (True, { + "user": "test_user", + "settings": { + "theme": "dark", + "notifications": True + }, + "items": [1, 2, 3, 4, 5] + }), + (False, {"test": "data", "number": 42}), + (True, {}), + (True, { + f"key_{i}": { + "value": i, + "list": list(range(10)), + "nested": {"deep": f"value_{i}"} + } + for i in range(100) + }), + ]) + def test_serialize_deserialize_workflow(self, to_str, original_data): + """Test complete workflow of serialize and deserialize with various data""" + # Serialize + serialized = serialize_b64(original_data, to_str=to_str) + + if to_str: + assert isinstance(serialized, str) + else: + assert isinstance(serialized, bytes) + + # Deserialize back + deserialized = deserialize_b64(serialized) + assert deserialized == original_data + + @patch('api.utils.configs.get_base_config') + def test_safe_deserialization_workflow(self, mock_config): + """Test safe deserialization workflow""" + mock_config.return_value = True + + test_data = {"safe": "data"} + serialized = serialize_b64(test_data, to_str=True) + + result = deserialize_b64(serialized) + + assert result == test_data + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/unit_test/api/test_json_encode.py b/test/unit_test/api/test_json_encode.py new file mode 100644 index 000000000..0114137da --- /dev/null +++ b/test/unit_test/api/test_json_encode.py @@ -0,0 +1,452 @@ +# +# Copyright 2025 The InfiniFlow 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for api.utils.json_encode module. +""" + +import pytest +import json +import datetime +from enum import Enum, IntEnum +from api.utils.json_encode import ( + BaseType, + CustomJSONEncoder, + json_dumps, + json_loads +) + + +class TestBaseTypeToDict: + """Test cases for BaseType.to_dict method""" + + def test_simple_object_to_dict(self): + """Test converting simple object to dictionary""" + class SimpleType(BaseType): + def __init__(self): + self.name = "test" + self.value = 42 + + obj = SimpleType() + result = obj.to_dict() + + assert isinstance(result, dict) + assert result == {"name": "test", "value": 42} + + def test_private_attributes_excluded(self): + """Test that private attributes (starting with _) are stripped""" + class PrivateType(BaseType): + def __init__(self): + self._private = "hidden" + self._another = "also hidden" + self.public = "visible" + + obj = PrivateType() + result = obj.to_dict() + + # Private attributes should have _ stripped in keys + assert "private" in result + assert "another" in result + assert "public" in result + assert result["private"] == "hidden" + assert result["another"] == "also hidden" + assert result["public"] == "visible" + + def test_empty_object(self): + """Test converting empty object to dictionary""" + class EmptyType(BaseType): + pass + + obj = EmptyType() + result = obj.to_dict() + + assert isinstance(result, dict) + assert 
len(result) == 0 + + def test_nested_object_to_dict(self): + """Test converting object with nested values""" + class NestedType(BaseType): + def __init__(self): + self.data = {"key": "value"} + self.items = [1, 2, 3] + self.text = "test" + + obj = NestedType() + result = obj.to_dict() + + assert result["data"] == {"key": "value"} + assert result["items"] == [1, 2, 3] + assert result["text"] == "test" + + +class TestBaseTypeToDictWithType: + """Test cases for BaseType.to_dict_with_type method""" + + def test_includes_type_information(self): + """Test that type information is included""" + class TypedObject(BaseType): + def __init__(self): + self.value = "test" + + obj = TypedObject() + result = obj.to_dict_with_type() + + assert "type" in result + assert "data" in result + assert result["type"] == "TypedObject" + + def test_includes_module_information(self): + """Test that module information is included""" + class ModuleObject(BaseType): + def __init__(self): + self.value = "test" + + obj = ModuleObject() + result = obj.to_dict_with_type() + + assert "module" in result + assert result["module"] is not None + + def test_nested_objects_with_types(self): + """Test nested BaseType objects include type info""" + class InnerType(BaseType): + def __init__(self): + self.inner_value = "inner" + + class OuterType(BaseType): + def __init__(self): + self.nested = InnerType() + self.value = "outer" + + obj = OuterType() + result = obj.to_dict_with_type() + + assert result["type"] == "OuterType" + assert "data" in result + assert "nested" in result["data"] + assert result["data"]["nested"]["type"] == "InnerType" + + def test_list_handling(self): + """Test handling of lists in to_dict_with_type""" + class ListType(BaseType): + def __init__(self): + self.items = [1, 2, 3] + + obj = ListType() + result = obj.to_dict_with_type() + + assert "items" in result["data"] + items_data = result["data"]["items"] + assert items_data["type"] == "list" + assert isinstance(items_data["data"], 
list) + + def test_dict_handling(self): + """Test handling of dictionaries in to_dict_with_type""" + class DictType(BaseType): + def __init__(self): + self.config = {"key": "value"} + + obj = DictType() + result = obj.to_dict_with_type() + + assert "config" in result["data"] + config_data = result["data"]["config"] + assert config_data["type"] == "dict" + + +class TestCustomJSONEncoder: + """Test cases for CustomJSONEncoder class""" + + def test_encode_datetime(self): + """Test encoding of datetime objects""" + dt = datetime.datetime(2025, 12, 3, 14, 30, 45) + result = json.dumps(dt, cls=CustomJSONEncoder) + + assert result == '"2025-12-03 14:30:45"' + + def test_encode_date(self): + """Test encoding of date objects""" + d = datetime.date(2025, 12, 3) + result = json.dumps(d, cls=CustomJSONEncoder) + + assert result == '"2025-12-03"' + + def test_encode_timedelta(self): + """Test encoding of timedelta objects""" + td = datetime.timedelta(days=1, hours=2, minutes=30) + result = json.dumps(td, cls=CustomJSONEncoder) + + assert isinstance(json.loads(result), str) + assert "1 day" in json.loads(result) + + def test_encode_enum(self): + """Test encoding of Enum objects""" + class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + result = json.dumps(Color.RED, cls=CustomJSONEncoder) + + assert result == '"red"' + + def test_encode_int_enum(self): + """Test encoding of IntEnum objects""" + class Priority(IntEnum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + + result = json.dumps(Priority.HIGH, cls=CustomJSONEncoder) + + assert result == '3' + + def test_encode_set(self): + """Test encoding of set objects""" + test_set = {1, 2, 3, 4, 5} + result = json.dumps(test_set, cls=CustomJSONEncoder) + + # Set should be converted to list + decoded = json.loads(result) + assert isinstance(decoded, list) + assert set(decoded) == test_set + + def test_encode_basetype_object(self): + """Test encoding of BaseType objects""" + class TestType(BaseType): + def __init__(self): + 
self.name = "test" + self.value = 42 + + obj = TestType() + result = json.dumps(obj, cls=CustomJSONEncoder) + + decoded = json.loads(result) + assert decoded["name"] == "test" + assert decoded["value"] == 42 + + def test_encode_basetype_with_type_flag(self): + """Test encoding BaseType with with_type flag""" + class TestType(BaseType): + def __init__(self): + self.value = "test" + + obj = TestType() + encoder = CustomJSONEncoder(with_type=True) + result = json.dumps(obj, cls=CustomJSONEncoder, with_type=True) + + decoded = json.loads(result) + assert "type" in decoded + assert "data" in decoded + + def test_encode_type_object(self): + """Test encoding of type objects""" + result = json.dumps(str, cls=CustomJSONEncoder) + + assert result == '"str"' + + def test_encode_nested_structures(self): + """Test encoding of nested structures with various types""" + class TestType(BaseType): + def __init__(self): + self.name = "test" + + data = { + "datetime": datetime.datetime(2025, 12, 3, 12, 0, 0), + "date": datetime.date(2025, 12, 3), + "set": {1, 2, 3}, + "object": TestType(), + "list": [1, 2, 3] + } + + result = json.dumps(data, cls=CustomJSONEncoder) + decoded = json.loads(result) + + assert decoded["datetime"] == "2025-12-03 12:00:00" + assert decoded["date"] == "2025-12-03" + assert isinstance(decoded["set"], list) + assert decoded["object"]["name"] == "test" + + +class TestJsonDumps: + """Test cases for json_dumps function""" + + def test_json_dumps_basic(self): + """Test basic json_dumps functionality""" + data = {"key": "value", "number": 42} + result = json_dumps(data) + + assert isinstance(result, str) + assert json.loads(result) == data + + def test_json_dumps_with_byte_false(self): + """Test json_dumps with byte=False returns string""" + data = {"test": "data"} + result = json_dumps(data, byte=False) + + assert isinstance(result, str) + + def test_json_dumps_with_byte_true(self): + """Test json_dumps with byte=True returns bytes""" + data = {"test": "data"} + 
result = json_dumps(data, byte=True) + + assert isinstance(result, bytes) + + def test_json_dumps_with_indent(self): + """Test json_dumps with indentation""" + data = {"key": "value"} + result = json_dumps(data, indent=2) + + assert isinstance(result, str) + assert "\n" in result # Indented JSON has newlines + + def test_json_dumps_with_type_false(self): + """Test json_dumps with with_type=False""" + class TestType(BaseType): + def __init__(self): + self.value = "test" + + obj = TestType() + result = json_dumps(obj, with_type=False) + + decoded = json.loads(result) + assert "type" not in decoded + assert decoded["value"] == "test" + + def test_json_dumps_with_type_true(self): + """Test json_dumps with with_type=True""" + class TestType(BaseType): + def __init__(self): + self.value = "test" + + obj = TestType() + result = json_dumps(obj, with_type=True) + + decoded = json.loads(result) + assert "type" in decoded + assert "data" in decoded + + def test_json_dumps_datetime(self): + """Test json_dumps with datetime objects""" + data = { + "timestamp": datetime.datetime(2025, 12, 3, 15, 30, 0) + } + result = json_dumps(data) + + decoded = json.loads(result) + assert decoded["timestamp"] == "2025-12-03 15:30:00" + + +class TestJsonLoads: + """Test cases for json_loads function""" + + def test_json_loads_string_input(self): + """Test json_loads with string input""" + json_string = '{"key": "value", "number": 42}' + result = json_loads(json_string) + + assert isinstance(result, dict) + assert result["key"] == "value" + assert result["number"] == 42 + + def test_json_loads_bytes_input(self): + """Test json_loads with bytes input""" + json_bytes = b'{"key": "value"}' + result = json_loads(json_bytes) + + assert isinstance(result, dict) + assert result["key"] == "value" + + def test_json_loads_with_object_hook(self): + """Test json_loads with object_hook parameter""" + def custom_hook(obj): + if "special" in obj: + obj["processed"] = True + return obj + + json_string = 
'{"special": "value"}' + result = json_loads(json_string, object_hook=custom_hook) + + assert result["processed"] is True + + def test_json_loads_empty_object(self): + """Test json_loads with empty object""" + result = json_loads('{}') + + assert isinstance(result, dict) + assert len(result) == 0 + + def test_json_loads_array(self): + """Test json_loads with array""" + result = json_loads('[1, 2, 3, 4, 5]') + + assert isinstance(result, list) + assert result == [1, 2, 3, 4, 5] + + +class TestRoundtripConversion: + """Test roundtrip conversions between dumps and loads""" + + def test_roundtrip_dict(self): + """Test roundtrip conversion of dictionary""" + original = {"key": "value", "number": 42, "list": [1, 2, 3]} + + dumped = json_dumps(original) + loaded = json_loads(dumped) + + assert loaded == original + + def test_roundtrip_with_bytes(self): + """Test roundtrip with byte conversion""" + original = {"test": "data"} + + dumped = json_dumps(original, byte=True) + loaded = json_loads(dumped) + + assert loaded == original + + def test_roundtrip_basetype(self): + """Test roundtrip with BaseType object""" + class TestType(BaseType): + def __init__(self): + self.name = "test" + self.value = 42 + + obj = TestType() + + dumped = json_dumps(obj) + loaded = json_loads(dumped) + + assert loaded["name"] == "test" + assert loaded["value"] == 42 + + @pytest.mark.parametrize("test_data", [ + {"simple": "dict"}, + [1, 2, 3, 4, 5], + {"nested": {"deep": {"structure": "value"}}}, + {"mixed": [1, "two", 3.0, {"four": 4}]}, + ]) + def test_roundtrip_various_structures(self, test_data): + """Test roundtrip for various data structures""" + dumped = json_dumps(test_data) + loaded = json_loads(dumped) + + assert loaded == test_data + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/unit_test/api/test_validation_utils.py b/test/unit_test/api/test_validation_utils.py new file mode 100644 index 000000000..32aba294c --- /dev/null +++ 
b/test/unit_test/api/test_validation_utils.py @@ -0,0 +1,788 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for api.utils.validation_utils module. +""" + +import pytest +from unittest.mock import Mock, AsyncMock +from uuid import UUID, uuid1 +from pydantic import BaseModel, Field, ValidationError +from werkzeug.exceptions import BadRequest, UnsupportedMediaType + +from api.utils.validation_utils import ( + validate_and_parse_json_request, + validate_and_parse_request_args, + format_validation_error_message, + normalize_str, + validate_uuid1_hex, + CreateDatasetReq, + UpdateDatasetReq, + DeleteDatasetReq, + ListDatasetReq, + ParserConfig, + RaptorConfig, + GraphragConfig, +) + + +class TestNormalizeStr: + """Test cases for normalize_str function""" + + def test_normalize_string_with_whitespace(self): + """Test normalization of string with leading/trailing whitespace""" + result = normalize_str(" Admin ") + assert result == "admin" + + def test_normalize_string_uppercase(self): + """Test normalization converts to lowercase""" + result = normalize_str("UPPERCASE") + assert result == "uppercase" + + def test_normalize_mixed_case(self): + """Test normalization of mixed case string""" + result = normalize_str(" MiXeD CaSe ") + assert result == "mixed case" + + def test_normalize_empty_string(self): + """Test normalization of empty string""" + result = normalize_str("") + assert result 
== "" + + def test_normalize_whitespace_only(self): + """Test normalization of whitespace-only string""" + result = normalize_str(" ") + assert result == "" + + def test_preserve_non_string_integer(self): + """Test that integers are preserved""" + result = normalize_str(42) + assert result == 42 + + def test_preserve_non_string_none(self): + """Test that None is preserved""" + result = normalize_str(None) + assert result is None + + def test_preserve_non_string_list(self): + """Test that lists are preserved""" + input_list = ["User", "Admin"] + result = normalize_str(input_list) + assert result == input_list + + def test_preserve_non_string_dict(self): + """Test that dicts are preserved""" + input_dict = {"role": "Admin"} + result = normalize_str(input_dict) + assert result == input_dict + + @pytest.mark.parametrize("input_val,expected", [ + ("ReadOnly", "readonly"), + (" ADMIN ", "admin"), + ("User", "user"), + ("", ""), + (123, 123), + (False, False), + (0, 0), + ]) + def test_various_inputs(self, input_val, expected): + """Test various input types""" + result = normalize_str(input_val) + assert result == expected + + +class TestValidateUuid1Hex: + """Test cases for validate_uuid1_hex function""" + + def test_valid_uuid1_string(self): + """Test validation of valid UUID1 string""" + uuid1_obj = uuid1() + uuid1_str = str(uuid1_obj) + + result = validate_uuid1_hex(uuid1_str) + + assert isinstance(result, str) + assert len(result) == 32 + assert result == uuid1_obj.hex + + def test_valid_uuid1_object(self): + """Test validation of valid UUID1 object""" + uuid1_obj = uuid1() + + result = validate_uuid1_hex(uuid1_obj) + + assert isinstance(result, str) + assert result == uuid1_obj.hex + + def test_uuid1_hex_no_hyphens(self): + """Test that result has no hyphens""" + uuid1_obj = uuid1() + result = validate_uuid1_hex(uuid1_obj) + + assert "-" not in result + + def test_invalid_uuid_string_raises_error(self): + """Test that invalid UUID string raises error""" + from 
class TestFormatValidationErrorMessage:
    """Test cases for format_validation_error_message function"""

    def test_single_validation_error(self):
        """Test formatting of single validation error"""
        class TestModel(BaseModel):
            name: str

        with pytest.raises(ValidationError) as exc_info:
            TestModel(name=123)

        result = format_validation_error_message(exc_info.value)

        assert "Field: " in result
        assert "Message:" in result
        assert "Value: <123>" in result

    def test_multiple_validation_errors(self):
        """Test formatting of multiple validation errors"""
        class TestModel(BaseModel):
            name: str
            age: int

        with pytest.raises(ValidationError) as exc_info:
            TestModel(name=123, age="not_an_int")

        result = format_validation_error_message(exc_info.value)

        # Both failing fields must be reported (the original asserted
        # `"Field: " in result` twice, which checked nothing new the
        # second time).
        assert result.count("Field: ") >= 2
        assert "name" in result
        assert "age" in result
        assert "\n" in result  # Multiple errors separated by newlines

    def test_long_value_truncation(self):
        """Test that long values are truncated in the formatted message"""
        # Trigger a real validation failure via a max_length constraint; the
        # original's first try/except here was dead code (bare `pass`) with an
        # unused `long_value`, and its assertions sat inside an `except`
        # block, so the test silently passed when no error was raised.
        class TestModel(BaseModel):
            text: str = Field(max_length=10)

        with pytest.raises(ValidationError) as exc_info:
            TestModel(text="x" * 200)

        result = format_validation_error_message(exc_info.value)

        # The 200-char offending value must not be echoed verbatim.
        assert "..." in result or len(result) < 500

    def test_nested_field_error(self):
        """Test formatting of nested field errors uses dotted paths"""
        class NestedModel(BaseModel):
            value: int

        class ParentModel(BaseModel):
            nested: NestedModel

        with pytest.raises(ValidationError) as exc_info:
            ParentModel(nested={"value": "not_an_int"})

        result = format_validation_error_message(exc_info.value)

        assert "nested.value" in result
validate_and_parse_request_args( + mock_request, + TestValidator, + extras={"internal_id": 123} + ) + + assert error is None + assert result is not None + assert result["param1"] == "value" + assert "internal_id" not in result # Extras should be removed + + def test_type_conversion(self): + """Test that Pydantic performs type conversion""" + class TestValidator(BaseModel): + number: int + + mock_request = Mock() + mock_request.args.to_dict.return_value = {"number": "42"} + + result, error = validate_and_parse_request_args(mock_request, TestValidator) + + assert error is None + assert result["number"] == 42 + assert isinstance(result["number"], int) + + +class TestValidateAndParseJsonRequest: + """Test cases for validate_and_parse_json_request function""" + + @pytest.mark.anyio + async def test_valid_json_request(self): + """Test validation of valid JSON request""" + class TestValidator(BaseModel): + name: str + value: int + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(return_value={"name": "test", "value": 42}) + + result, error = await validate_and_parse_json_request(mock_request, TestValidator) + + assert error is None + assert result is not None + assert result["name"] == "test" + assert result["value"] == 42 + + @pytest.mark.anyio + async def test_unsupported_content_type(self): + """Test handling of unsupported content type""" + class TestValidator(BaseModel): + name: str + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(side_effect=UnsupportedMediaType()) + mock_request.content_type = "text/xml" + + result, error = await validate_and_parse_json_request(mock_request, TestValidator) + + assert result is None + assert error is not None + assert "Unsupported content type" in error + assert "text/xml" in error + + @pytest.mark.anyio + async def test_malformed_json(self): + """Test handling of malformed JSON""" + class TestValidator(BaseModel): + name: str + + mock_request = AsyncMock() + mock_request.get_json = 
AsyncMock(side_effect=BadRequest()) + + result, error = await validate_and_parse_json_request(mock_request, TestValidator) + + assert result is None + assert error is not None + assert "Malformed JSON syntax" in error + + @pytest.mark.anyio + async def test_invalid_payload_type(self): + """Test handling of non-dict payload""" + class TestValidator(BaseModel): + name: str + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(return_value=["not", "a", "dict"]) + + result, error = await validate_and_parse_json_request(mock_request, TestValidator) + + assert result is None + assert error is not None + assert "Invalid request payload" in error + assert "list" in error + + @pytest.mark.anyio + async def test_validation_error(self): + """Test handling of Pydantic validation errors""" + class TestValidator(BaseModel): + name: str + age: int + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(return_value={"name": 123, "age": "not_int"}) + + result, error = await validate_and_parse_json_request(mock_request, TestValidator) + + assert result is None + assert error is not None + assert "Field:" in error + + @pytest.mark.anyio + async def test_with_extras_parameter(self): + """Test validation with extras parameter""" + class TestValidator(BaseModel): + name: str + user_id: str + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(return_value={"name": "test"}) + + result, error = await validate_and_parse_json_request( + mock_request, + TestValidator, + extras={"user_id": "user_123"} + ) + + assert error is None + assert result is not None + assert result["name"] == "test" + assert "user_id" not in result # Extras should be removed + + @pytest.mark.anyio + async def test_exclude_unset_parameter(self): + """Test exclude_unset parameter""" + class TestValidator(BaseModel): + name: str + optional: str = "default" + + mock_request = AsyncMock() + mock_request.get_json = AsyncMock(return_value={"name": "test"}) + + result, error = await 
validate_and_parse_json_request( + mock_request, + TestValidator, + exclude_unset=True + ) + + assert error is None + assert result is not None + assert "name" in result + assert "optional" not in result # Not set, should be excluded + + +class TestCreateDatasetReq: + """Test cases for CreateDatasetReq validation""" + + def test_valid_dataset_creation(self): + """Test valid dataset creation request""" + data = { + "name": "Test Dataset", + "embedding_model": "text-embedding-3-large@openai" + } + + dataset = CreateDatasetReq(**data) + + assert dataset.name == "Test Dataset" + assert dataset.embedding_model == "text-embedding-3-large@openai" + + def test_name_whitespace_stripping(self): + """Test that name whitespace is stripped""" + data = { + "name": " Test Dataset " + } + + dataset = CreateDatasetReq(**data) + + assert dataset.name == "Test Dataset" + + def test_empty_name_raises_error(self): + """Test that empty name raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq(name="") + + def test_invalid_embedding_model_format(self): + """Test that invalid embedding model format raises error""" + with pytest.raises(ValidationError) as exc_info: + CreateDatasetReq( + name="Test", + embedding_model="invalid_model" + ) + + assert "format_invalid" in str(exc_info.value) + + def test_embedding_model_without_provider(self): + """Test embedding model without provider raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq( + name="Test", + embedding_model="model_name@" + ) + + def test_valid_avatar_base64(self): + """Test valid base64 avatar""" + data = { + "name": "Test", + "avatar": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + } + + dataset = CreateDatasetReq(**data) + + assert dataset.avatar.startswith("data:image/png") + + def test_invalid_avatar_mime_type(self): + """Test invalid avatar MIME type raises error""" + with pytest.raises(ValidationError): + 
CreateDatasetReq( + name="Test", + avatar="data:video/mp4;base64,abc123" + ) + + def test_avatar_missing_data_prefix(self): + """Test avatar missing data: prefix raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq( + name="Test", + avatar="image/png;base64,abc123" + ) + + def test_default_chunk_method(self): + """Test default chunk_method is set to 'naive'""" + dataset = CreateDatasetReq(name="Test") + + assert dataset.chunk_method == "naive" + + def test_valid_chunk_methods(self): + """Test various valid chunk methods""" + valid_methods = ["naive", "book", "email", "laws", "manual", "one", + "paper", "picture", "presentation", "qa", "table", "tag"] + + for method in valid_methods: + dataset = CreateDatasetReq(name="Test", chunk_method=method) + assert dataset.chunk_method == method + + def test_invalid_chunk_method(self): + """Test invalid chunk method raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq(name="Test", chunk_method="invalid_method") + + def test_pipeline_id_validation(self): + """Test pipeline_id validation""" + # Valid 32-char hex + dataset = CreateDatasetReq( + name="Test", + parse_type=1, + pipeline_id="a" * 32 + ) + assert dataset.pipeline_id == "a" * 32 + + def test_pipeline_id_wrong_length(self): + """Test pipeline_id with wrong length raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq( + name="Test", + parse_type=1, + pipeline_id="abc" # Too short + ) + + def test_pipeline_id_non_hex(self): + """Test pipeline_id with non-hex characters raises error""" + with pytest.raises(ValidationError): + CreateDatasetReq( + name="Test", + parse_type=1, + pipeline_id="g" * 32 # 'g' is not hex + ) + + +class TestUpdateDatasetReq: + """Test cases for UpdateDatasetReq validation""" + + def test_valid_update_request(self): + """Test valid dataset update request""" + uuid1_obj = uuid1() + + data = { + "dataset_id": str(uuid1_obj), + "name": "Updated Dataset" + } + + dataset = 
UpdateDatasetReq(**data) + + assert dataset.dataset_id == uuid1_obj.hex + assert dataset.name == "Updated Dataset" + + def test_dataset_id_uuid1_validation(self): + """Test that dataset_id must be UUID1""" + from uuid import uuid4 + + with pytest.raises(ValidationError): + UpdateDatasetReq( + dataset_id=str(uuid4()), + name="Test" + ) + + def test_pagerank_validation(self): + """Test pagerank field validation""" + uuid1_obj = uuid1() + + dataset = UpdateDatasetReq( + dataset_id=str(uuid1_obj), + name="Test", + pagerank=50 + ) + + assert dataset.pagerank == 50 + + def test_pagerank_out_of_range(self): + """Test pagerank out of range raises error""" + uuid1_obj = uuid1() + + with pytest.raises(ValidationError): + UpdateDatasetReq( + dataset_id=str(uuid1_obj), + name="Test", + pagerank=101 # Max is 100 + ) + + +class TestDeleteDatasetReq: + """Test cases for DeleteDatasetReq validation""" + + def test_valid_delete_request(self): + """Test valid delete request""" + uuid1_obj1 = uuid1() + uuid1_obj2 = uuid1() + + req = DeleteDatasetReq(ids=[str(uuid1_obj1), str(uuid1_obj2)]) + + assert len(req.ids) == 2 + assert uuid1_obj1.hex in req.ids + assert uuid1_obj2.hex in req.ids + + def test_duplicate_ids_raises_error(self): + """Test that duplicate IDs raise error""" + uuid1_obj = uuid1() + + with pytest.raises(ValidationError) as exc_info: + DeleteDatasetReq(ids=[str(uuid1_obj), str(uuid1_obj)]) + + assert "duplicate" in str(exc_info.value).lower() + + def test_empty_ids_list(self): + """Test empty IDs list""" + req = DeleteDatasetReq(ids=[]) + + assert req.ids == [] + + def test_none_ids(self): + """Test None IDs""" + req = DeleteDatasetReq(ids=None) + + assert req.ids is None + + +class TestListDatasetReq: + """Test cases for ListDatasetReq validation""" + + def test_default_values(self): + """Test default values for list request""" + req = ListDatasetReq() + + assert req.page == 1 + assert req.page_size == 30 + assert req.orderby == "create_time" + assert req.desc is True 
+ + def test_custom_pagination(self): + """Test custom pagination values""" + req = ListDatasetReq(page=2, page_size=50) + + assert req.page == 2 + assert req.page_size == 50 + + def test_page_minimum_value(self): + """Test page minimum value validation""" + with pytest.raises(ValidationError): + ListDatasetReq(page=0) + + def test_valid_orderby_values(self): + """Test valid orderby values""" + req1 = ListDatasetReq(orderby="create_time") + req2 = ListDatasetReq(orderby="update_time") + + assert req1.orderby == "create_time" + assert req2.orderby == "update_time" + + def test_invalid_orderby_value(self): + """Test invalid orderby value raises error""" + with pytest.raises(ValidationError): + ListDatasetReq(orderby="invalid_field") + + +class TestParserConfig: + """Test cases for ParserConfig validation""" + + def test_default_parser_config(self): + """Test default parser configuration""" + config = ParserConfig() + + assert config.chunk_token_num == 512 + assert config.auto_keywords == 0 + assert config.auto_questions == 0 + + def test_custom_parser_config(self): + """Test custom parser configuration""" + config = ParserConfig( + chunk_token_num=1024, + auto_keywords=5, + auto_questions=3 + ) + + assert config.chunk_token_num == 1024 + assert config.auto_keywords == 5 + assert config.auto_questions == 3 + + def test_chunk_token_num_range(self): + """Test chunk_token_num range validation""" + with pytest.raises(ValidationError): + ParserConfig(chunk_token_num=3000) # Max is 2048 + + def test_raptor_config_integration(self): + """Test raptor config integration""" + config = ParserConfig( + raptor=RaptorConfig(use_raptor=True, max_token=512) + ) + + assert config.raptor.use_raptor is True + assert config.raptor.max_token == 512 + + +class TestRaptorConfig: + """Test cases for RaptorConfig validation""" + + def test_default_raptor_config(self): + """Test default raptor configuration""" + config = RaptorConfig() + + assert config.use_raptor is False + assert 
config.max_token == 256 + assert config.threshold == 0.1 + + def test_custom_raptor_config(self): + """Test custom raptor configuration""" + config = RaptorConfig( + use_raptor=True, + max_token=512, + threshold=0.2 + ) + + assert config.use_raptor is True + assert config.max_token == 512 + assert config.threshold == 0.2 + + def test_threshold_range(self): + """Test threshold range validation""" + with pytest.raises(ValidationError): + RaptorConfig(threshold=1.5) # Max is 1.0 + + +class TestGraphragConfig: + """Test cases for GraphragConfig validation""" + + def test_default_graphrag_config(self): + """Test default graphrag configuration""" + config = GraphragConfig() + + assert config.use_graphrag is False + assert config.method == "light" + assert "organization" in config.entity_types + + def test_custom_graphrag_config(self): + """Test custom graphrag configuration""" + config = GraphragConfig( + use_graphrag=True, + method="general", + entity_types=["person", "location"] + ) + + assert config.use_graphrag is True + assert config.method == "general" + assert config.entity_types == ["person", "location"] + + def test_invalid_method(self): + """Test invalid method raises error""" + with pytest.raises(ValidationError): + GraphragConfig(method="invalid") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/unit_test/conftest.py b/test/unit_test/conftest.py new file mode 100644 index 000000000..434588e49 --- /dev/null +++ b/test/unit_test/conftest.py @@ -0,0 +1,23 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
from pathlib import Path

# Make the repository root importable (so `import api...` resolves) without
# requiring an editable install; idempotent across repeated conftest loads.
project_root = Path(__file__).parents[2]
if str(project_root) not in sys.path:
    sys.path[:0] = [str(project_root)]