Merge 5f17a2b6a9 into 7d23c3aed0
This commit is contained in:
commit
db6f29e529
6 changed files with 1629 additions and 0 deletions
15
test/unit_test/api/__init__.py
Normal file
15
test/unit_test/api/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
120
test/unit_test/api/test_common_utils.py
Normal file
120
test/unit_test/api/test_common_utils.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Unit tests for api.utils.common module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from api.utils.common import string_to_bytes, bytes_to_string
|
||||
|
||||
|
||||
class TestStringToBytes:
|
||||
"""Test cases for string_to_bytes function"""
|
||||
|
||||
def test_string_input_returns_bytes(self):
|
||||
"""Test that string input is converted to bytes"""
|
||||
input_string = "hello world"
|
||||
result = string_to_bytes(input_string)
|
||||
|
||||
assert isinstance(result, bytes)
|
||||
assert result == b"hello world"
|
||||
|
||||
def test_bytes_input_returns_same_bytes(self):
|
||||
"""Test that bytes input is returned unchanged"""
|
||||
input_bytes = b"hello world"
|
||||
result = string_to_bytes(input_bytes)
|
||||
|
||||
assert isinstance(result, bytes)
|
||||
assert result == input_bytes
|
||||
assert result is input_bytes # Should be the same object
|
||||
|
||||
@pytest.mark.parametrize("input_val,expected", [
|
||||
("test", b"test"),
|
||||
("", b""),
|
||||
("123", b"123"),
|
||||
("Hello World", b"Hello World"),
|
||||
("Hello 世界 🌍", "Hello 世界 🌍".encode("utf-8")),
|
||||
("Hello, world! @#$%^&*()", b"Hello, world! @#$%^&*()"),
|
||||
("Newline\nTab\tQuote\"", b"Newline\nTab\tQuote\""),
|
||||
])
|
||||
def test_various_string_inputs(self, input_val, expected):
|
||||
"""Test conversion of various string inputs including unicode and special characters"""
|
||||
result = string_to_bytes(input_val)
|
||||
assert isinstance(result, bytes)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestBytesToString:
|
||||
"""Test cases for bytes_to_string function"""
|
||||
|
||||
@pytest.mark.parametrize("input_bytes,expected", [
|
||||
(b"hello world", "hello world"),
|
||||
(b"test", "test"),
|
||||
(b"", ""),
|
||||
(b"123", "123"),
|
||||
(b"Hello World", "Hello World"),
|
||||
("Hello 世界 🌍".encode("utf-8"), "Hello 世界 🌍"),
|
||||
(b"Special: @#$%^&*()", "Special: @#$%^&*()"),
|
||||
])
|
||||
def test_various_bytes_inputs(self, input_bytes, expected):
|
||||
"""Test conversion of various bytes inputs including unicode"""
|
||||
result = bytes_to_string(input_bytes)
|
||||
assert isinstance(result, str)
|
||||
assert result == expected
|
||||
|
||||
def test_invalid_utf8_raises_error(self):
|
||||
"""Test that invalid UTF-8 bytes raise an error"""
|
||||
# Invalid UTF-8 sequence
|
||||
invalid_bytes = b"\xff\xfe"
|
||||
|
||||
with pytest.raises(UnicodeDecodeError):
|
||||
bytes_to_string(invalid_bytes)
|
||||
|
||||
|
||||
class TestRoundtripConversion:
|
||||
"""Test roundtrip conversions between string and bytes"""
|
||||
|
||||
@pytest.mark.parametrize("test_string", [
|
||||
"Simple text",
|
||||
"Hello, World! 世界",
|
||||
"Unicode: 你好世界 🌍",
|
||||
"Special: !@#$%^&*()",
|
||||
"Multiline\nWith\tTabs",
|
||||
"",
|
||||
])
|
||||
def test_string_to_bytes_to_string(self, test_string):
|
||||
"""Test converting string to bytes and back for various inputs"""
|
||||
as_bytes = string_to_bytes(test_string)
|
||||
back_to_string = bytes_to_string(as_bytes)
|
||||
assert back_to_string == test_string
|
||||
|
||||
@pytest.mark.parametrize("test_bytes", [
|
||||
b"Simple text",
|
||||
b"Hello, World!",
|
||||
"Unicode: 你好世界 🌍".encode("utf-8"),
|
||||
b"Special: !@#$%^&*()",
|
||||
b"",
|
||||
])
|
||||
def test_bytes_to_string_to_bytes(self, test_bytes):
|
||||
"""Test converting bytes to string and back for various inputs"""
|
||||
as_string = bytes_to_string(test_bytes)
|
||||
back_to_bytes = string_to_bytes(as_string)
|
||||
assert back_to_bytes == test_bytes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
231
test/unit_test/api/test_configs.py
Normal file
231
test/unit_test/api/test_configs.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Unit tests for api.utils.configs module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pickle
|
||||
import base64
|
||||
from unittest.mock import patch
|
||||
from api.utils.configs import (
|
||||
serialize_b64,
|
||||
deserialize_b64,
|
||||
RestrictedUnpickler,
|
||||
restricted_loads,
|
||||
safe_module
|
||||
)
|
||||
|
||||
|
||||
class TestSerializeB64:
|
||||
"""Test cases for serialize_b64 function"""
|
||||
|
||||
@pytest.mark.parametrize("test_data", [
|
||||
{"key": "value", "number": 42},
|
||||
[1, 2, 3, "test", {"nested": "dict"}],
|
||||
"Hello, World!",
|
||||
12345,
|
||||
{
|
||||
"list": [1, 2, 3],
|
||||
"dict": {"nested": {"deep": "value"}},
|
||||
"tuple": (1, 2, 3),
|
||||
"string": "test",
|
||||
"number": 42.5
|
||||
},
|
||||
None,
|
||||
{},
|
||||
[],
|
||||
])
|
||||
def test_serialize_returns_bytes(self, test_data):
|
||||
"""Test serialization of various data types returns bytes"""
|
||||
result = serialize_b64(test_data, to_str=False)
|
||||
|
||||
assert isinstance(result, bytes)
|
||||
# Should be valid base64
|
||||
decoded = base64.b64decode(result)
|
||||
assert isinstance(decoded, bytes)
|
||||
|
||||
def test_serialize_with_to_str_true(self):
|
||||
"""Test serialization with to_str=True returns string"""
|
||||
test_data = {"test": "data"}
|
||||
result = serialize_b64(test_data, to_str=True)
|
||||
|
||||
assert isinstance(result, str)
|
||||
# Should be valid base64 string
|
||||
base64.b64decode(result) # Should not raise
|
||||
|
||||
|
||||
class TestDeserializeB64:
|
||||
"""Test cases for deserialize_b64 function"""
|
||||
|
||||
@pytest.mark.parametrize("to_str", [True, False])
|
||||
def test_deserialize_string_and_bytes_input(self, to_str):
|
||||
"""Test deserialization with both string and bytes input"""
|
||||
test_data = {"key": "value"}
|
||||
serialized = serialize_b64(test_data, to_str=to_str)
|
||||
|
||||
result = deserialize_b64(serialized)
|
||||
|
||||
assert result == test_data
|
||||
|
||||
@patch('api.utils.configs.get_base_config')
|
||||
def test_deserialize_with_safe_module_disabled(self, mock_config):
|
||||
"""Test deserialization with safe module checking disabled"""
|
||||
mock_config.return_value = False
|
||||
|
||||
test_data = {"test": "data"}
|
||||
serialized = serialize_b64(test_data, to_str=True)
|
||||
|
||||
result = deserialize_b64(serialized)
|
||||
|
||||
assert result == test_data
|
||||
mock_config.assert_called_once_with('use_deserialize_safe_module', False)
|
||||
|
||||
@patch('api.utils.configs.get_base_config')
|
||||
def test_deserialize_with_safe_module_enabled(self, mock_config):
|
||||
"""Test deserialization with safe module checking enabled"""
|
||||
mock_config.return_value = True
|
||||
|
||||
# Simple data that doesn't require unsafe modules
|
||||
test_data = {"test": "data", "number": 42}
|
||||
serialized = serialize_b64(test_data, to_str=True)
|
||||
|
||||
result = deserialize_b64(serialized)
|
||||
|
||||
assert result == test_data
|
||||
|
||||
@pytest.mark.parametrize("test_data", [
|
||||
{"key": "value"},
|
||||
{
|
||||
"string": "test",
|
||||
"number": 123,
|
||||
"list": [1, 2, 3],
|
||||
"nested": {"key": "value"}
|
||||
},
|
||||
[1, 2, 3, 4, 5],
|
||||
"simple string",
|
||||
42,
|
||||
3.14,
|
||||
None,
|
||||
{"nested": {"deep": {"structure": "value"}}},
|
||||
])
|
||||
def test_roundtrip_various_data_types(self, test_data):
|
||||
"""Test roundtrip serialization and deserialization for various data types"""
|
||||
serialized = serialize_b64(test_data)
|
||||
deserialized = deserialize_b64(serialized)
|
||||
|
||||
assert deserialized == test_data
|
||||
|
||||
|
||||
class TestRestrictedUnpickler:
|
||||
"""Test cases for RestrictedUnpickler class"""
|
||||
|
||||
@pytest.mark.parametrize("test_data", [
|
||||
{"test": "data"},
|
||||
[1, 2, 3, "test", {"key": "value"}],
|
||||
{"nested": {"deep": "structure"}},
|
||||
[1, 2, 3],
|
||||
"simple string",
|
||||
])
|
||||
def test_restricted_loads_with_safe_data(self, test_data):
|
||||
"""Test restricted_loads with various safe data types"""
|
||||
pickled = pickle.dumps(test_data)
|
||||
|
||||
result = restricted_loads(pickled)
|
||||
|
||||
assert result == test_data
|
||||
|
||||
@patch('api.utils.configs.get_base_config')
|
||||
def test_blocks_unsafe_modules(self, mock_config):
|
||||
"""Test that unsafe modules are blocked"""
|
||||
mock_config.return_value = True
|
||||
|
||||
# Try to pickle something from an unsafe module
|
||||
# We'll simulate this by creating a pickle that references os module
|
||||
class UnsafeClass:
|
||||
__module__ = 'os.path'
|
||||
def __reduce__(self):
|
||||
return (eval, ("1+1",))
|
||||
|
||||
try:
|
||||
unsafe_obj = UnsafeClass()
|
||||
pickled = pickle.dumps(unsafe_obj)
|
||||
|
||||
with pytest.raises(pickle.UnpicklingError):
|
||||
restricted_loads(pickled)
|
||||
except:
|
||||
# If we can't create the unsafe pickle, skip this test
|
||||
pytest.skip("Unable to create unsafe pickle for testing")
|
||||
|
||||
def test_safe_module_set_contains_expected_modules(self):
|
||||
"""Test that safe_module set contains expected modules"""
|
||||
assert 'numpy' in safe_module
|
||||
assert 'rag_flow' in safe_module
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Integration tests for serialization/deserialization workflows"""
|
||||
|
||||
@pytest.mark.parametrize("to_str,original_data", [
|
||||
(True, {
|
||||
"user": "test_user",
|
||||
"settings": {
|
||||
"theme": "dark",
|
||||
"notifications": True
|
||||
},
|
||||
"items": [1, 2, 3, 4, 5]
|
||||
}),
|
||||
(False, {"test": "data", "number": 42}),
|
||||
(True, {}),
|
||||
(True, {
|
||||
f"key_{i}": {
|
||||
"value": i,
|
||||
"list": list(range(10)),
|
||||
"nested": {"deep": f"value_{i}"}
|
||||
}
|
||||
for i in range(100)
|
||||
}),
|
||||
])
|
||||
def test_serialize_deserialize_workflow(self, to_str, original_data):
|
||||
"""Test complete workflow of serialize and deserialize with various data"""
|
||||
# Serialize
|
||||
serialized = serialize_b64(original_data, to_str=to_str)
|
||||
|
||||
if to_str:
|
||||
assert isinstance(serialized, str)
|
||||
else:
|
||||
assert isinstance(serialized, bytes)
|
||||
|
||||
# Deserialize back
|
||||
deserialized = deserialize_b64(serialized)
|
||||
assert deserialized == original_data
|
||||
|
||||
@patch('api.utils.configs.get_base_config')
|
||||
def test_safe_deserialization_workflow(self, mock_config):
|
||||
"""Test safe deserialization workflow"""
|
||||
mock_config.return_value = True
|
||||
|
||||
test_data = {"safe": "data"}
|
||||
serialized = serialize_b64(test_data, to_str=True)
|
||||
|
||||
result = deserialize_b64(serialized)
|
||||
|
||||
assert result == test_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
452
test/unit_test/api/test_json_encode.py
Normal file
452
test/unit_test/api/test_json_encode.py
Normal file
|
|
@ -0,0 +1,452 @@
|
|||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Unit tests for api.utils.json_encode module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import datetime
|
||||
from enum import Enum, IntEnum
|
||||
from api.utils.json_encode import (
|
||||
BaseType,
|
||||
CustomJSONEncoder,
|
||||
json_dumps,
|
||||
json_loads
|
||||
)
|
||||
|
||||
|
||||
class TestBaseTypeToDict:
|
||||
"""Test cases for BaseType.to_dict method"""
|
||||
|
||||
def test_simple_object_to_dict(self):
|
||||
"""Test converting simple object to dictionary"""
|
||||
class SimpleType(BaseType):
|
||||
def __init__(self):
|
||||
self.name = "test"
|
||||
self.value = 42
|
||||
|
||||
obj = SimpleType()
|
||||
result = obj.to_dict()
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result == {"name": "test", "value": 42}
|
||||
|
||||
def test_private_attributes_excluded(self):
|
||||
"""Test that private attributes (starting with _) are stripped"""
|
||||
class PrivateType(BaseType):
|
||||
def __init__(self):
|
||||
self._private = "hidden"
|
||||
self._another = "also hidden"
|
||||
self.public = "visible"
|
||||
|
||||
obj = PrivateType()
|
||||
result = obj.to_dict()
|
||||
|
||||
# Private attributes should have _ stripped in keys
|
||||
assert "private" in result
|
||||
assert "another" in result
|
||||
assert "public" in result
|
||||
assert result["private"] == "hidden"
|
||||
assert result["another"] == "also hidden"
|
||||
assert result["public"] == "visible"
|
||||
|
||||
def test_empty_object(self):
|
||||
"""Test converting empty object to dictionary"""
|
||||
class EmptyType(BaseType):
|
||||
pass
|
||||
|
||||
obj = EmptyType()
|
||||
result = obj.to_dict()
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_nested_object_to_dict(self):
|
||||
"""Test converting object with nested values"""
|
||||
class NestedType(BaseType):
|
||||
def __init__(self):
|
||||
self.data = {"key": "value"}
|
||||
self.items = [1, 2, 3]
|
||||
self.text = "test"
|
||||
|
||||
obj = NestedType()
|
||||
result = obj.to_dict()
|
||||
|
||||
assert result["data"] == {"key": "value"}
|
||||
assert result["items"] == [1, 2, 3]
|
||||
assert result["text"] == "test"
|
||||
|
||||
|
||||
class TestBaseTypeToDictWithType:
|
||||
"""Test cases for BaseType.to_dict_with_type method"""
|
||||
|
||||
def test_includes_type_information(self):
|
||||
"""Test that type information is included"""
|
||||
class TypedObject(BaseType):
|
||||
def __init__(self):
|
||||
self.value = "test"
|
||||
|
||||
obj = TypedObject()
|
||||
result = obj.to_dict_with_type()
|
||||
|
||||
assert "type" in result
|
||||
assert "data" in result
|
||||
assert result["type"] == "TypedObject"
|
||||
|
||||
def test_includes_module_information(self):
|
||||
"""Test that module information is included"""
|
||||
class ModuleObject(BaseType):
|
||||
def __init__(self):
|
||||
self.value = "test"
|
||||
|
||||
obj = ModuleObject()
|
||||
result = obj.to_dict_with_type()
|
||||
|
||||
assert "module" in result
|
||||
assert result["module"] is not None
|
||||
|
||||
def test_nested_objects_with_types(self):
|
||||
"""Test nested BaseType objects include type info"""
|
||||
class InnerType(BaseType):
|
||||
def __init__(self):
|
||||
self.inner_value = "inner"
|
||||
|
||||
class OuterType(BaseType):
|
||||
def __init__(self):
|
||||
self.nested = InnerType()
|
||||
self.value = "outer"
|
||||
|
||||
obj = OuterType()
|
||||
result = obj.to_dict_with_type()
|
||||
|
||||
assert result["type"] == "OuterType"
|
||||
assert "data" in result
|
||||
assert "nested" in result["data"]
|
||||
assert result["data"]["nested"]["type"] == "InnerType"
|
||||
|
||||
def test_list_handling(self):
|
||||
"""Test handling of lists in to_dict_with_type"""
|
||||
class ListType(BaseType):
|
||||
def __init__(self):
|
||||
self.items = [1, 2, 3]
|
||||
|
||||
obj = ListType()
|
||||
result = obj.to_dict_with_type()
|
||||
|
||||
assert "items" in result["data"]
|
||||
items_data = result["data"]["items"]
|
||||
assert items_data["type"] == "list"
|
||||
assert isinstance(items_data["data"], list)
|
||||
|
||||
def test_dict_handling(self):
|
||||
"""Test handling of dictionaries in to_dict_with_type"""
|
||||
class DictType(BaseType):
|
||||
def __init__(self):
|
||||
self.config = {"key": "value"}
|
||||
|
||||
obj = DictType()
|
||||
result = obj.to_dict_with_type()
|
||||
|
||||
assert "config" in result["data"]
|
||||
config_data = result["data"]["config"]
|
||||
assert config_data["type"] == "dict"
|
||||
|
||||
|
||||
class TestCustomJSONEncoder:
|
||||
"""Test cases for CustomJSONEncoder class"""
|
||||
|
||||
def test_encode_datetime(self):
|
||||
"""Test encoding of datetime objects"""
|
||||
dt = datetime.datetime(2025, 12, 3, 14, 30, 45)
|
||||
result = json.dumps(dt, cls=CustomJSONEncoder)
|
||||
|
||||
assert result == '"2025-12-03 14:30:45"'
|
||||
|
||||
def test_encode_date(self):
|
||||
"""Test encoding of date objects"""
|
||||
d = datetime.date(2025, 12, 3)
|
||||
result = json.dumps(d, cls=CustomJSONEncoder)
|
||||
|
||||
assert result == '"2025-12-03"'
|
||||
|
||||
def test_encode_timedelta(self):
|
||||
"""Test encoding of timedelta objects"""
|
||||
td = datetime.timedelta(days=1, hours=2, minutes=30)
|
||||
result = json.dumps(td, cls=CustomJSONEncoder)
|
||||
|
||||
assert isinstance(json.loads(result), str)
|
||||
assert "1 day" in json.loads(result)
|
||||
|
||||
def test_encode_enum(self):
|
||||
"""Test encoding of Enum objects"""
|
||||
class Color(Enum):
|
||||
RED = "red"
|
||||
GREEN = "green"
|
||||
BLUE = "blue"
|
||||
|
||||
result = json.dumps(Color.RED, cls=CustomJSONEncoder)
|
||||
|
||||
assert result == '"red"'
|
||||
|
||||
def test_encode_int_enum(self):
|
||||
"""Test encoding of IntEnum objects"""
|
||||
class Priority(IntEnum):
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
|
||||
result = json.dumps(Priority.HIGH, cls=CustomJSONEncoder)
|
||||
|
||||
assert result == '3'
|
||||
|
||||
def test_encode_set(self):
|
||||
"""Test encoding of set objects"""
|
||||
test_set = {1, 2, 3, 4, 5}
|
||||
result = json.dumps(test_set, cls=CustomJSONEncoder)
|
||||
|
||||
# Set should be converted to list
|
||||
decoded = json.loads(result)
|
||||
assert isinstance(decoded, list)
|
||||
assert set(decoded) == test_set
|
||||
|
||||
def test_encode_basetype_object(self):
|
||||
"""Test encoding of BaseType objects"""
|
||||
class TestType(BaseType):
|
||||
def __init__(self):
|
||||
self.name = "test"
|
||||
self.value = 42
|
||||
|
||||
obj = TestType()
|
||||
result = json.dumps(obj, cls=CustomJSONEncoder)
|
||||
|
||||
decoded = json.loads(result)
|
||||
assert decoded["name"] == "test"
|
||||
assert decoded["value"] == 42
|
||||
|
||||
def test_encode_basetype_with_type_flag(self):
|
||||
"""Test encoding BaseType with with_type flag"""
|
||||
class TestType(BaseType):
|
||||
def __init__(self):
|
||||
self.value = "test"
|
||||
|
||||
obj = TestType()
|
||||
encoder = CustomJSONEncoder(with_type=True)
|
||||
result = json.dumps(obj, cls=CustomJSONEncoder, with_type=True)
|
||||
|
||||
decoded = json.loads(result)
|
||||
assert "type" in decoded
|
||||
assert "data" in decoded
|
||||
|
||||
def test_encode_type_object(self):
|
||||
"""Test encoding of type objects"""
|
||||
result = json.dumps(str, cls=CustomJSONEncoder)
|
||||
|
||||
assert result == '"str"'
|
||||
|
||||
def test_encode_nested_structures(self):
|
||||
"""Test encoding of nested structures with various types"""
|
||||
class TestType(BaseType):
|
||||
def __init__(self):
|
||||
self.name = "test"
|
||||
|
||||
data = {
|
||||
"datetime": datetime.datetime(2025, 12, 3, 12, 0, 0),
|
||||
"date": datetime.date(2025, 12, 3),
|
||||
"set": {1, 2, 3},
|
||||
"object": TestType(),
|
||||
"list": [1, 2, 3]
|
||||
}
|
||||
|
||||
result = json.dumps(data, cls=CustomJSONEncoder)
|
||||
decoded = json.loads(result)
|
||||
|
||||
assert decoded["datetime"] == "2025-12-03 12:00:00"
|
||||
assert decoded["date"] == "2025-12-03"
|
||||
assert isinstance(decoded["set"], list)
|
||||
assert decoded["object"]["name"] == "test"
|
||||
|
||||
|
||||
class TestJsonDumps:
|
||||
"""Test cases for json_dumps function"""
|
||||
|
||||
def test_json_dumps_basic(self):
|
||||
"""Test basic json_dumps functionality"""
|
||||
data = {"key": "value", "number": 42}
|
||||
result = json_dumps(data)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert json.loads(result) == data
|
||||
|
||||
def test_json_dumps_with_byte_false(self):
|
||||
"""Test json_dumps with byte=False returns string"""
|
||||
data = {"test": "data"}
|
||||
result = json_dumps(data, byte=False)
|
||||
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_json_dumps_with_byte_true(self):
|
||||
"""Test json_dumps with byte=True returns bytes"""
|
||||
data = {"test": "data"}
|
||||
result = json_dumps(data, byte=True)
|
||||
|
||||
assert isinstance(result, bytes)
|
||||
|
||||
def test_json_dumps_with_indent(self):
|
||||
"""Test json_dumps with indentation"""
|
||||
data = {"key": "value"}
|
||||
result = json_dumps(data, indent=2)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "\n" in result # Indented JSON has newlines
|
||||
|
||||
def test_json_dumps_with_type_false(self):
|
||||
"""Test json_dumps with with_type=False"""
|
||||
class TestType(BaseType):
|
||||
def __init__(self):
|
||||
self.value = "test"
|
||||
|
||||
obj = TestType()
|
||||
result = json_dumps(obj, with_type=False)
|
||||
|
||||
decoded = json.loads(result)
|
||||
assert "type" not in decoded
|
||||
assert decoded["value"] == "test"
|
||||
|
||||
def test_json_dumps_with_type_true(self):
|
||||
"""Test json_dumps with with_type=True"""
|
||||
class TestType(BaseType):
|
||||
def __init__(self):
|
||||
self.value = "test"
|
||||
|
||||
obj = TestType()
|
||||
result = json_dumps(obj, with_type=True)
|
||||
|
||||
decoded = json.loads(result)
|
||||
assert "type" in decoded
|
||||
assert "data" in decoded
|
||||
|
||||
def test_json_dumps_datetime(self):
|
||||
"""Test json_dumps with datetime objects"""
|
||||
data = {
|
||||
"timestamp": datetime.datetime(2025, 12, 3, 15, 30, 0)
|
||||
}
|
||||
result = json_dumps(data)
|
||||
|
||||
decoded = json.loads(result)
|
||||
assert decoded["timestamp"] == "2025-12-03 15:30:00"
|
||||
|
||||
|
||||
class TestJsonLoads:
|
||||
"""Test cases for json_loads function"""
|
||||
|
||||
def test_json_loads_string_input(self):
|
||||
"""Test json_loads with string input"""
|
||||
json_string = '{"key": "value", "number": 42}'
|
||||
result = json_loads(json_string)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["key"] == "value"
|
||||
assert result["number"] == 42
|
||||
|
||||
def test_json_loads_bytes_input(self):
|
||||
"""Test json_loads with bytes input"""
|
||||
json_bytes = b'{"key": "value"}'
|
||||
result = json_loads(json_bytes)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["key"] == "value"
|
||||
|
||||
def test_json_loads_with_object_hook(self):
|
||||
"""Test json_loads with object_hook parameter"""
|
||||
def custom_hook(obj):
|
||||
if "special" in obj:
|
||||
obj["processed"] = True
|
||||
return obj
|
||||
|
||||
json_string = '{"special": "value"}'
|
||||
result = json_loads(json_string, object_hook=custom_hook)
|
||||
|
||||
assert result["processed"] is True
|
||||
|
||||
def test_json_loads_empty_object(self):
|
||||
"""Test json_loads with empty object"""
|
||||
result = json_loads('{}')
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert len(result) == 0
|
||||
|
||||
def test_json_loads_array(self):
|
||||
"""Test json_loads with array"""
|
||||
result = json_loads('[1, 2, 3, 4, 5]')
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert result == [1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
class TestRoundtripConversion:
|
||||
"""Test roundtrip conversions between dumps and loads"""
|
||||
|
||||
def test_roundtrip_dict(self):
|
||||
"""Test roundtrip conversion of dictionary"""
|
||||
original = {"key": "value", "number": 42, "list": [1, 2, 3]}
|
||||
|
||||
dumped = json_dumps(original)
|
||||
loaded = json_loads(dumped)
|
||||
|
||||
assert loaded == original
|
||||
|
||||
def test_roundtrip_with_bytes(self):
|
||||
"""Test roundtrip with byte conversion"""
|
||||
original = {"test": "data"}
|
||||
|
||||
dumped = json_dumps(original, byte=True)
|
||||
loaded = json_loads(dumped)
|
||||
|
||||
assert loaded == original
|
||||
|
||||
def test_roundtrip_basetype(self):
|
||||
"""Test roundtrip with BaseType object"""
|
||||
class TestType(BaseType):
|
||||
def __init__(self):
|
||||
self.name = "test"
|
||||
self.value = 42
|
||||
|
||||
obj = TestType()
|
||||
|
||||
dumped = json_dumps(obj)
|
||||
loaded = json_loads(dumped)
|
||||
|
||||
assert loaded["name"] == "test"
|
||||
assert loaded["value"] == 42
|
||||
|
||||
@pytest.mark.parametrize("test_data", [
|
||||
{"simple": "dict"},
|
||||
[1, 2, 3, 4, 5],
|
||||
{"nested": {"deep": {"structure": "value"}}},
|
||||
{"mixed": [1, "two", 3.0, {"four": 4}]},
|
||||
])
|
||||
def test_roundtrip_various_structures(self, test_data):
|
||||
"""Test roundtrip for various data structures"""
|
||||
dumped = json_dumps(test_data)
|
||||
loaded = json_loads(dumped)
|
||||
|
||||
assert loaded == test_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
788
test/unit_test/api/test_validation_utils.py
Normal file
788
test/unit_test/api/test_validation_utils.py
Normal file
|
|
@ -0,0 +1,788 @@
|
|||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Unit tests for api.utils.validation_utils module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from uuid import UUID, uuid1
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
|
||||
|
||||
from api.utils.validation_utils import (
|
||||
validate_and_parse_json_request,
|
||||
validate_and_parse_request_args,
|
||||
format_validation_error_message,
|
||||
normalize_str,
|
||||
validate_uuid1_hex,
|
||||
CreateDatasetReq,
|
||||
UpdateDatasetReq,
|
||||
DeleteDatasetReq,
|
||||
ListDatasetReq,
|
||||
ParserConfig,
|
||||
RaptorConfig,
|
||||
GraphragConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestNormalizeStr:
|
||||
"""Test cases for normalize_str function"""
|
||||
|
||||
def test_normalize_string_with_whitespace(self):
|
||||
"""Test normalization of string with leading/trailing whitespace"""
|
||||
result = normalize_str(" Admin ")
|
||||
assert result == "admin"
|
||||
|
||||
def test_normalize_string_uppercase(self):
|
||||
"""Test normalization converts to lowercase"""
|
||||
result = normalize_str("UPPERCASE")
|
||||
assert result == "uppercase"
|
||||
|
||||
def test_normalize_mixed_case(self):
|
||||
"""Test normalization of mixed case string"""
|
||||
result = normalize_str(" MiXeD CaSe ")
|
||||
assert result == "mixed case"
|
||||
|
||||
def test_normalize_empty_string(self):
|
||||
"""Test normalization of empty string"""
|
||||
result = normalize_str("")
|
||||
assert result == ""
|
||||
|
||||
def test_normalize_whitespace_only(self):
|
||||
"""Test normalization of whitespace-only string"""
|
||||
result = normalize_str(" ")
|
||||
assert result == ""
|
||||
|
||||
def test_preserve_non_string_integer(self):
|
||||
"""Test that integers are preserved"""
|
||||
result = normalize_str(42)
|
||||
assert result == 42
|
||||
|
||||
def test_preserve_non_string_none(self):
|
||||
"""Test that None is preserved"""
|
||||
result = normalize_str(None)
|
||||
assert result is None
|
||||
|
||||
def test_preserve_non_string_list(self):
|
||||
"""Test that lists are preserved"""
|
||||
input_list = ["User", "Admin"]
|
||||
result = normalize_str(input_list)
|
||||
assert result == input_list
|
||||
|
||||
def test_preserve_non_string_dict(self):
|
||||
"""Test that dicts are preserved"""
|
||||
input_dict = {"role": "Admin"}
|
||||
result = normalize_str(input_dict)
|
||||
assert result == input_dict
|
||||
|
||||
@pytest.mark.parametrize("input_val,expected", [
|
||||
("ReadOnly", "readonly"),
|
||||
(" ADMIN ", "admin"),
|
||||
("User", "user"),
|
||||
("", ""),
|
||||
(123, 123),
|
||||
(False, False),
|
||||
(0, 0),
|
||||
])
|
||||
def test_various_inputs(self, input_val, expected):
|
||||
"""Test various input types"""
|
||||
result = normalize_str(input_val)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestValidateUuid1Hex:
|
||||
"""Test cases for validate_uuid1_hex function"""
|
||||
|
||||
def test_valid_uuid1_string(self):
|
||||
"""Test validation of valid UUID1 string"""
|
||||
uuid1_obj = uuid1()
|
||||
uuid1_str = str(uuid1_obj)
|
||||
|
||||
result = validate_uuid1_hex(uuid1_str)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert len(result) == 32
|
||||
assert result == uuid1_obj.hex
|
||||
|
||||
def test_valid_uuid1_object(self):
|
||||
"""Test validation of valid UUID1 object"""
|
||||
uuid1_obj = uuid1()
|
||||
|
||||
result = validate_uuid1_hex(uuid1_obj)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == uuid1_obj.hex
|
||||
|
||||
def test_uuid1_hex_no_hyphens(self):
|
||||
"""Test that result has no hyphens"""
|
||||
uuid1_obj = uuid1()
|
||||
result = validate_uuid1_hex(uuid1_obj)
|
||||
|
||||
assert "-" not in result
|
||||
|
||||
def test_invalid_uuid_string_raises_error(self):
|
||||
"""Test that invalid UUID string raises error"""
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
with pytest.raises(PydanticCustomError) as exc_info:
|
||||
validate_uuid1_hex("not-a-uuid")
|
||||
|
||||
assert exc_info.value.type == "invalid_UUID1_format"
|
||||
|
||||
def test_non_uuid1_version_raises_error(self):
|
||||
"""Test that non-UUID1 version raises error"""
|
||||
from pydantic_core import PydanticCustomError
|
||||
from uuid import uuid4
|
||||
|
||||
uuid4_obj = uuid4()
|
||||
|
||||
with pytest.raises(PydanticCustomError) as exc_info:
|
||||
validate_uuid1_hex(uuid4_obj)
|
||||
|
||||
assert exc_info.value.type == "invalid_UUID1_format"
|
||||
|
||||
def test_integer_input_raises_error(self):
|
||||
"""Test that integer input raises error"""
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
with pytest.raises(PydanticCustomError):
|
||||
validate_uuid1_hex(12345)
|
||||
|
||||
def test_none_input_raises_error(self):
|
||||
"""Test that None input raises error"""
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
with pytest.raises(PydanticCustomError):
|
||||
validate_uuid1_hex(None)
|
||||
|
||||
|
||||
class TestFormatValidationErrorMessage:
|
||||
"""Test cases for format_validation_error_message function"""
|
||||
|
||||
def test_single_validation_error(self):
|
||||
"""Test formatting of single validation error"""
|
||||
class TestModel(BaseModel):
|
||||
name: str
|
||||
|
||||
try:
|
||||
TestModel(name=123)
|
||||
except ValidationError as e:
|
||||
result = format_validation_error_message(e)
|
||||
|
||||
assert "Field: <name>" in result
|
||||
assert "Message:" in result
|
||||
assert "Value: <123>" in result
|
||||
|
||||
def test_multiple_validation_errors(self):
|
||||
"""Test formatting of multiple validation errors"""
|
||||
class TestModel(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
try:
|
||||
TestModel(name=123, age="not_an_int")
|
||||
except ValidationError as e:
|
||||
result = format_validation_error_message(e)
|
||||
|
||||
assert "Field: <name>" in result
|
||||
assert "Field: <age>" in result
|
||||
assert "\n" in result # Multiple errors separated by newlines
|
||||
|
||||
def test_long_value_truncation(self):
|
||||
"""Test that long values are truncated"""
|
||||
class TestModel(BaseModel):
|
||||
text: str
|
||||
|
||||
long_value = "x" * 200
|
||||
|
||||
try:
|
||||
TestModel(text=123) # Wrong type to trigger error
|
||||
except ValidationError as e:
|
||||
# Manually create error with long value
|
||||
pass
|
||||
|
||||
# Create a model with max_length constraint
|
||||
class TestModel2(BaseModel):
|
||||
text: str = Field(max_length=10)
|
||||
|
||||
try:
|
||||
TestModel2(text="x" * 200)
|
||||
except ValidationError as e:
|
||||
result = format_validation_error_message(e)
|
||||
|
||||
# Check that value is truncated
|
||||
assert "..." in result or len(result) < 500
|
||||
|
||||
def test_nested_field_error(self):
|
||||
"""Test formatting of nested field errors"""
|
||||
class NestedModel(BaseModel):
|
||||
value: int
|
||||
|
||||
class ParentModel(BaseModel):
|
||||
nested: NestedModel
|
||||
|
||||
try:
|
||||
ParentModel(nested={"value": "not_an_int"})
|
||||
except ValidationError as e:
|
||||
result = format_validation_error_message(e)
|
||||
|
||||
assert "nested.value" in result
|
||||
|
||||
|
||||
class TestValidateAndParseRequestArgs:
|
||||
"""Test cases for validate_and_parse_request_args function"""
|
||||
|
||||
def test_valid_request_args(self):
|
||||
"""Test validation of valid request arguments"""
|
||||
class TestValidator(BaseModel):
|
||||
param1: str
|
||||
param2: int = 10
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.args.to_dict.return_value = {"param1": "value", "param2": "20"}
|
||||
|
||||
result, error = validate_and_parse_request_args(mock_request, TestValidator)
|
||||
|
||||
assert error is None
|
||||
assert result is not None
|
||||
assert result["param1"] == "value"
|
||||
assert result["param2"] == 20
|
||||
|
||||
def test_missing_required_field(self):
|
||||
"""Test validation with missing required field"""
|
||||
class TestValidator(BaseModel):
|
||||
required_field: str
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.args.to_dict.return_value = {}
|
||||
|
||||
result, error = validate_and_parse_request_args(mock_request, TestValidator)
|
||||
|
||||
assert result is None
|
||||
assert error is not None
|
||||
assert "required_field" in error
|
||||
|
||||
def test_with_extras_parameter(self):
|
||||
"""Test validation with extras parameter"""
|
||||
class TestValidator(BaseModel):
|
||||
param1: str
|
||||
internal_id: int
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.args.to_dict.return_value = {"param1": "value"}
|
||||
|
||||
result, error = validate_and_parse_request_args(
|
||||
mock_request,
|
||||
TestValidator,
|
||||
extras={"internal_id": 123}
|
||||
)
|
||||
|
||||
assert error is None
|
||||
assert result is not None
|
||||
assert result["param1"] == "value"
|
||||
assert "internal_id" not in result # Extras should be removed
|
||||
|
||||
def test_type_conversion(self):
|
||||
"""Test that Pydantic performs type conversion"""
|
||||
class TestValidator(BaseModel):
|
||||
number: int
|
||||
|
||||
mock_request = Mock()
|
||||
mock_request.args.to_dict.return_value = {"number": "42"}
|
||||
|
||||
result, error = validate_and_parse_request_args(mock_request, TestValidator)
|
||||
|
||||
assert error is None
|
||||
assert result["number"] == 42
|
||||
assert isinstance(result["number"], int)
|
||||
|
||||
|
||||
class TestValidateAndParseJsonRequest:
|
||||
"""Test cases for validate_and_parse_json_request function"""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_valid_json_request(self):
|
||||
"""Test validation of valid JSON request"""
|
||||
class TestValidator(BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.get_json = AsyncMock(return_value={"name": "test", "value": 42})
|
||||
|
||||
result, error = await validate_and_parse_json_request(mock_request, TestValidator)
|
||||
|
||||
assert error is None
|
||||
assert result is not None
|
||||
assert result["name"] == "test"
|
||||
assert result["value"] == 42
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_unsupported_content_type(self):
|
||||
"""Test handling of unsupported content type"""
|
||||
class TestValidator(BaseModel):
|
||||
name: str
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.get_json = AsyncMock(side_effect=UnsupportedMediaType())
|
||||
mock_request.content_type = "text/xml"
|
||||
|
||||
result, error = await validate_and_parse_json_request(mock_request, TestValidator)
|
||||
|
||||
assert result is None
|
||||
assert error is not None
|
||||
assert "Unsupported content type" in error
|
||||
assert "text/xml" in error
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_malformed_json(self):
|
||||
"""Test handling of malformed JSON"""
|
||||
class TestValidator(BaseModel):
|
||||
name: str
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.get_json = AsyncMock(side_effect=BadRequest())
|
||||
|
||||
result, error = await validate_and_parse_json_request(mock_request, TestValidator)
|
||||
|
||||
assert result is None
|
||||
assert error is not None
|
||||
assert "Malformed JSON syntax" in error
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_payload_type(self):
|
||||
"""Test handling of non-dict payload"""
|
||||
class TestValidator(BaseModel):
|
||||
name: str
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.get_json = AsyncMock(return_value=["not", "a", "dict"])
|
||||
|
||||
result, error = await validate_and_parse_json_request(mock_request, TestValidator)
|
||||
|
||||
assert result is None
|
||||
assert error is not None
|
||||
assert "Invalid request payload" in error
|
||||
assert "list" in error
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_validation_error(self):
|
||||
"""Test handling of Pydantic validation errors"""
|
||||
class TestValidator(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.get_json = AsyncMock(return_value={"name": 123, "age": "not_int"})
|
||||
|
||||
result, error = await validate_and_parse_json_request(mock_request, TestValidator)
|
||||
|
||||
assert result is None
|
||||
assert error is not None
|
||||
assert "Field:" in error
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_with_extras_parameter(self):
|
||||
"""Test validation with extras parameter"""
|
||||
class TestValidator(BaseModel):
|
||||
name: str
|
||||
user_id: str
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.get_json = AsyncMock(return_value={"name": "test"})
|
||||
|
||||
result, error = await validate_and_parse_json_request(
|
||||
mock_request,
|
||||
TestValidator,
|
||||
extras={"user_id": "user_123"}
|
||||
)
|
||||
|
||||
assert error is None
|
||||
assert result is not None
|
||||
assert result["name"] == "test"
|
||||
assert "user_id" not in result # Extras should be removed
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_exclude_unset_parameter(self):
|
||||
"""Test exclude_unset parameter"""
|
||||
class TestValidator(BaseModel):
|
||||
name: str
|
||||
optional: str = "default"
|
||||
|
||||
mock_request = AsyncMock()
|
||||
mock_request.get_json = AsyncMock(return_value={"name": "test"})
|
||||
|
||||
result, error = await validate_and_parse_json_request(
|
||||
mock_request,
|
||||
TestValidator,
|
||||
exclude_unset=True
|
||||
)
|
||||
|
||||
assert error is None
|
||||
assert result is not None
|
||||
assert "name" in result
|
||||
assert "optional" not in result # Not set, should be excluded
|
||||
|
||||
|
||||
class TestCreateDatasetReq:
|
||||
"""Test cases for CreateDatasetReq validation"""
|
||||
|
||||
def test_valid_dataset_creation(self):
|
||||
"""Test valid dataset creation request"""
|
||||
data = {
|
||||
"name": "Test Dataset",
|
||||
"embedding_model": "text-embedding-3-large@openai"
|
||||
}
|
||||
|
||||
dataset = CreateDatasetReq(**data)
|
||||
|
||||
assert dataset.name == "Test Dataset"
|
||||
assert dataset.embedding_model == "text-embedding-3-large@openai"
|
||||
|
||||
def test_name_whitespace_stripping(self):
|
||||
"""Test that name whitespace is stripped"""
|
||||
data = {
|
||||
"name": " Test Dataset "
|
||||
}
|
||||
|
||||
dataset = CreateDatasetReq(**data)
|
||||
|
||||
assert dataset.name == "Test Dataset"
|
||||
|
||||
def test_empty_name_raises_error(self):
|
||||
"""Test that empty name raises error"""
|
||||
with pytest.raises(ValidationError):
|
||||
CreateDatasetReq(name="")
|
||||
|
||||
def test_invalid_embedding_model_format(self):
|
||||
"""Test that invalid embedding model format raises error"""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
CreateDatasetReq(
|
||||
name="Test",
|
||||
embedding_model="invalid_model"
|
||||
)
|
||||
|
||||
assert "format_invalid" in str(exc_info.value)
|
||||
|
||||
def test_embedding_model_without_provider(self):
|
||||
"""Test embedding model without provider raises error"""
|
||||
with pytest.raises(ValidationError):
|
||||
CreateDatasetReq(
|
||||
name="Test",
|
||||
embedding_model="model_name@"
|
||||
)
|
||||
|
||||
def test_valid_avatar_base64(self):
|
||||
"""Test valid base64 avatar"""
|
||||
data = {
|
||||
"name": "Test",
|
||||
"avatar": ""
|
||||
}
|
||||
|
||||
dataset = CreateDatasetReq(**data)
|
||||
|
||||
assert dataset.avatar.startswith(""
|
||||
)
|
||||
|
||||
def test_avatar_missing_data_prefix(self):
|
||||
"""Test avatar missing data: prefix raises error"""
|
||||
with pytest.raises(ValidationError):
|
||||
CreateDatasetReq(
|
||||
name="Test",
|
||||
avatar="image/png;base64,abc123"
|
||||
)
|
||||
|
||||
def test_default_chunk_method(self):
|
||||
"""Test default chunk_method is set to 'naive'"""
|
||||
dataset = CreateDatasetReq(name="Test")
|
||||
|
||||
assert dataset.chunk_method == "naive"
|
||||
|
||||
def test_valid_chunk_methods(self):
|
||||
"""Test various valid chunk methods"""
|
||||
valid_methods = ["naive", "book", "email", "laws", "manual", "one",
|
||||
"paper", "picture", "presentation", "qa", "table", "tag"]
|
||||
|
||||
for method in valid_methods:
|
||||
dataset = CreateDatasetReq(name="Test", chunk_method=method)
|
||||
assert dataset.chunk_method == method
|
||||
|
||||
def test_invalid_chunk_method(self):
|
||||
"""Test invalid chunk method raises error"""
|
||||
with pytest.raises(ValidationError):
|
||||
CreateDatasetReq(name="Test", chunk_method="invalid_method")
|
||||
|
||||
def test_pipeline_id_validation(self):
|
||||
"""Test pipeline_id validation"""
|
||||
# Valid 32-char hex
|
||||
dataset = CreateDatasetReq(
|
||||
name="Test",
|
||||
parse_type=1,
|
||||
pipeline_id="a" * 32
|
||||
)
|
||||
assert dataset.pipeline_id == "a" * 32
|
||||
|
||||
def test_pipeline_id_wrong_length(self):
|
||||
"""Test pipeline_id with wrong length raises error"""
|
||||
with pytest.raises(ValidationError):
|
||||
CreateDatasetReq(
|
||||
name="Test",
|
||||
parse_type=1,
|
||||
pipeline_id="abc" # Too short
|
||||
)
|
||||
|
||||
def test_pipeline_id_non_hex(self):
|
||||
"""Test pipeline_id with non-hex characters raises error"""
|
||||
with pytest.raises(ValidationError):
|
||||
CreateDatasetReq(
|
||||
name="Test",
|
||||
parse_type=1,
|
||||
pipeline_id="g" * 32 # 'g' is not hex
|
||||
)
|
||||
|
||||
|
||||
class TestUpdateDatasetReq:
|
||||
"""Test cases for UpdateDatasetReq validation"""
|
||||
|
||||
def test_valid_update_request(self):
|
||||
"""Test valid dataset update request"""
|
||||
uuid1_obj = uuid1()
|
||||
|
||||
data = {
|
||||
"dataset_id": str(uuid1_obj),
|
||||
"name": "Updated Dataset"
|
||||
}
|
||||
|
||||
dataset = UpdateDatasetReq(**data)
|
||||
|
||||
assert dataset.dataset_id == uuid1_obj.hex
|
||||
assert dataset.name == "Updated Dataset"
|
||||
|
||||
def test_dataset_id_uuid1_validation(self):
|
||||
"""Test that dataset_id must be UUID1"""
|
||||
from uuid import uuid4
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UpdateDatasetReq(
|
||||
dataset_id=str(uuid4()),
|
||||
name="Test"
|
||||
)
|
||||
|
||||
def test_pagerank_validation(self):
|
||||
"""Test pagerank field validation"""
|
||||
uuid1_obj = uuid1()
|
||||
|
||||
dataset = UpdateDatasetReq(
|
||||
dataset_id=str(uuid1_obj),
|
||||
name="Test",
|
||||
pagerank=50
|
||||
)
|
||||
|
||||
assert dataset.pagerank == 50
|
||||
|
||||
def test_pagerank_out_of_range(self):
|
||||
"""Test pagerank out of range raises error"""
|
||||
uuid1_obj = uuid1()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UpdateDatasetReq(
|
||||
dataset_id=str(uuid1_obj),
|
||||
name="Test",
|
||||
pagerank=101 # Max is 100
|
||||
)
|
||||
|
||||
|
||||
class TestDeleteDatasetReq:
|
||||
"""Test cases for DeleteDatasetReq validation"""
|
||||
|
||||
def test_valid_delete_request(self):
|
||||
"""Test valid delete request"""
|
||||
uuid1_obj1 = uuid1()
|
||||
uuid1_obj2 = uuid1()
|
||||
|
||||
req = DeleteDatasetReq(ids=[str(uuid1_obj1), str(uuid1_obj2)])
|
||||
|
||||
assert len(req.ids) == 2
|
||||
assert uuid1_obj1.hex in req.ids
|
||||
assert uuid1_obj2.hex in req.ids
|
||||
|
||||
def test_duplicate_ids_raises_error(self):
|
||||
"""Test that duplicate IDs raise error"""
|
||||
uuid1_obj = uuid1()
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
DeleteDatasetReq(ids=[str(uuid1_obj), str(uuid1_obj)])
|
||||
|
||||
assert "duplicate" in str(exc_info.value).lower()
|
||||
|
||||
def test_empty_ids_list(self):
|
||||
"""Test empty IDs list"""
|
||||
req = DeleteDatasetReq(ids=[])
|
||||
|
||||
assert req.ids == []
|
||||
|
||||
def test_none_ids(self):
|
||||
"""Test None IDs"""
|
||||
req = DeleteDatasetReq(ids=None)
|
||||
|
||||
assert req.ids is None
|
||||
|
||||
|
||||
class TestListDatasetReq:
|
||||
"""Test cases for ListDatasetReq validation"""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values for list request"""
|
||||
req = ListDatasetReq()
|
||||
|
||||
assert req.page == 1
|
||||
assert req.page_size == 30
|
||||
assert req.orderby == "create_time"
|
||||
assert req.desc is True
|
||||
|
||||
def test_custom_pagination(self):
|
||||
"""Test custom pagination values"""
|
||||
req = ListDatasetReq(page=2, page_size=50)
|
||||
|
||||
assert req.page == 2
|
||||
assert req.page_size == 50
|
||||
|
||||
def test_page_minimum_value(self):
|
||||
"""Test page minimum value validation"""
|
||||
with pytest.raises(ValidationError):
|
||||
ListDatasetReq(page=0)
|
||||
|
||||
def test_valid_orderby_values(self):
|
||||
"""Test valid orderby values"""
|
||||
req1 = ListDatasetReq(orderby="create_time")
|
||||
req2 = ListDatasetReq(orderby="update_time")
|
||||
|
||||
assert req1.orderby == "create_time"
|
||||
assert req2.orderby == "update_time"
|
||||
|
||||
def test_invalid_orderby_value(self):
|
||||
"""Test invalid orderby value raises error"""
|
||||
with pytest.raises(ValidationError):
|
||||
ListDatasetReq(orderby="invalid_field")
|
||||
|
||||
|
||||
class TestParserConfig:
|
||||
"""Test cases for ParserConfig validation"""
|
||||
|
||||
def test_default_parser_config(self):
|
||||
"""Test default parser configuration"""
|
||||
config = ParserConfig()
|
||||
|
||||
assert config.chunk_token_num == 512
|
||||
assert config.auto_keywords == 0
|
||||
assert config.auto_questions == 0
|
||||
|
||||
def test_custom_parser_config(self):
|
||||
"""Test custom parser configuration"""
|
||||
config = ParserConfig(
|
||||
chunk_token_num=1024,
|
||||
auto_keywords=5,
|
||||
auto_questions=3
|
||||
)
|
||||
|
||||
assert config.chunk_token_num == 1024
|
||||
assert config.auto_keywords == 5
|
||||
assert config.auto_questions == 3
|
||||
|
||||
def test_chunk_token_num_range(self):
|
||||
"""Test chunk_token_num range validation"""
|
||||
with pytest.raises(ValidationError):
|
||||
ParserConfig(chunk_token_num=3000) # Max is 2048
|
||||
|
||||
def test_raptor_config_integration(self):
|
||||
"""Test raptor config integration"""
|
||||
config = ParserConfig(
|
||||
raptor=RaptorConfig(use_raptor=True, max_token=512)
|
||||
)
|
||||
|
||||
assert config.raptor.use_raptor is True
|
||||
assert config.raptor.max_token == 512
|
||||
|
||||
|
||||
class TestRaptorConfig:
|
||||
"""Test cases for RaptorConfig validation"""
|
||||
|
||||
def test_default_raptor_config(self):
|
||||
"""Test default raptor configuration"""
|
||||
config = RaptorConfig()
|
||||
|
||||
assert config.use_raptor is False
|
||||
assert config.max_token == 256
|
||||
assert config.threshold == 0.1
|
||||
|
||||
def test_custom_raptor_config(self):
|
||||
"""Test custom raptor configuration"""
|
||||
config = RaptorConfig(
|
||||
use_raptor=True,
|
||||
max_token=512,
|
||||
threshold=0.2
|
||||
)
|
||||
|
||||
assert config.use_raptor is True
|
||||
assert config.max_token == 512
|
||||
assert config.threshold == 0.2
|
||||
|
||||
def test_threshold_range(self):
|
||||
"""Test threshold range validation"""
|
||||
with pytest.raises(ValidationError):
|
||||
RaptorConfig(threshold=1.5) # Max is 1.0
|
||||
|
||||
|
||||
class TestGraphragConfig:
|
||||
"""Test cases for GraphragConfig validation"""
|
||||
|
||||
def test_default_graphrag_config(self):
|
||||
"""Test default graphrag configuration"""
|
||||
config = GraphragConfig()
|
||||
|
||||
assert config.use_graphrag is False
|
||||
assert config.method == "light"
|
||||
assert "organization" in config.entity_types
|
||||
|
||||
def test_custom_graphrag_config(self):
|
||||
"""Test custom graphrag configuration"""
|
||||
config = GraphragConfig(
|
||||
use_graphrag=True,
|
||||
method="general",
|
||||
entity_types=["person", "location"]
|
||||
)
|
||||
|
||||
assert config.use_graphrag is True
|
||||
assert config.method == "general"
|
||||
assert config.entity_types == ["person", "location"]
|
||||
|
||||
def test_invalid_method(self):
|
||||
"""Test invalid method raises error"""
|
||||
with pytest.raises(ValidationError):
|
||||
GraphragConfig(method="invalid")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
23
test/unit_test/conftest.py
Normal file
23
test/unit_test/conftest.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
Loading…
Add table
Reference in a new issue