This commit is contained in:
Mobile-Crest 2025-12-14 09:53:31 +04:30 committed by GitHub
commit db6f29e529
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 1629 additions and 0 deletions

View file

@ -0,0 +1,15 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View file

@ -0,0 +1,120 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for api.utils.common module.
"""
import pytest
from api.utils.common import string_to_bytes, bytes_to_string
class TestStringToBytes:
"""Test cases for string_to_bytes function"""
def test_string_input_returns_bytes(self):
"""Test that string input is converted to bytes"""
input_string = "hello world"
result = string_to_bytes(input_string)
assert isinstance(result, bytes)
assert result == b"hello world"
def test_bytes_input_returns_same_bytes(self):
"""Test that bytes input is returned unchanged"""
input_bytes = b"hello world"
result = string_to_bytes(input_bytes)
assert isinstance(result, bytes)
assert result == input_bytes
assert result is input_bytes # Should be the same object
@pytest.mark.parametrize("input_val,expected", [
("test", b"test"),
("", b""),
("123", b"123"),
("Hello World", b"Hello World"),
("Hello 世界 🌍", "Hello 世界 🌍".encode("utf-8")),
("Hello, world! @#$%^&*()", b"Hello, world! @#$%^&*()"),
("Newline\nTab\tQuote\"", b"Newline\nTab\tQuote\""),
])
def test_various_string_inputs(self, input_val, expected):
"""Test conversion of various string inputs including unicode and special characters"""
result = string_to_bytes(input_val)
assert isinstance(result, bytes)
assert result == expected
class TestBytesToString:
"""Test cases for bytes_to_string function"""
@pytest.mark.parametrize("input_bytes,expected", [
(b"hello world", "hello world"),
(b"test", "test"),
(b"", ""),
(b"123", "123"),
(b"Hello World", "Hello World"),
("Hello 世界 🌍".encode("utf-8"), "Hello 世界 🌍"),
(b"Special: @#$%^&*()", "Special: @#$%^&*()"),
])
def test_various_bytes_inputs(self, input_bytes, expected):
"""Test conversion of various bytes inputs including unicode"""
result = bytes_to_string(input_bytes)
assert isinstance(result, str)
assert result == expected
def test_invalid_utf8_raises_error(self):
"""Test that invalid UTF-8 bytes raise an error"""
# Invalid UTF-8 sequence
invalid_bytes = b"\xff\xfe"
with pytest.raises(UnicodeDecodeError):
bytes_to_string(invalid_bytes)
class TestRoundtripConversion:
"""Test roundtrip conversions between string and bytes"""
@pytest.mark.parametrize("test_string", [
"Simple text",
"Hello, World! 世界",
"Unicode: 你好世界 🌍",
"Special: !@#$%^&*()",
"Multiline\nWith\tTabs",
"",
])
def test_string_to_bytes_to_string(self, test_string):
"""Test converting string to bytes and back for various inputs"""
as_bytes = string_to_bytes(test_string)
back_to_string = bytes_to_string(as_bytes)
assert back_to_string == test_string
@pytest.mark.parametrize("test_bytes", [
b"Simple text",
b"Hello, World!",
"Unicode: 你好世界 🌍".encode("utf-8"),
b"Special: !@#$%^&*()",
b"",
])
def test_bytes_to_string_to_bytes(self, test_bytes):
"""Test converting bytes to string and back for various inputs"""
as_string = bytes_to_string(test_bytes)
back_to_bytes = string_to_bytes(as_string)
assert back_to_bytes == test_bytes
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View file

@ -0,0 +1,231 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for api.utils.configs module.
"""
import pytest
import pickle
import base64
from unittest.mock import patch
from api.utils.configs import (
serialize_b64,
deserialize_b64,
RestrictedUnpickler,
restricted_loads,
safe_module
)
class TestSerializeB64:
"""Test cases for serialize_b64 function"""
@pytest.mark.parametrize("test_data", [
{"key": "value", "number": 42},
[1, 2, 3, "test", {"nested": "dict"}],
"Hello, World!",
12345,
{
"list": [1, 2, 3],
"dict": {"nested": {"deep": "value"}},
"tuple": (1, 2, 3),
"string": "test",
"number": 42.5
},
None,
{},
[],
])
def test_serialize_returns_bytes(self, test_data):
"""Test serialization of various data types returns bytes"""
result = serialize_b64(test_data, to_str=False)
assert isinstance(result, bytes)
# Should be valid base64
decoded = base64.b64decode(result)
assert isinstance(decoded, bytes)
def test_serialize_with_to_str_true(self):
"""Test serialization with to_str=True returns string"""
test_data = {"test": "data"}
result = serialize_b64(test_data, to_str=True)
assert isinstance(result, str)
# Should be valid base64 string
base64.b64decode(result) # Should not raise
class TestDeserializeB64:
"""Test cases for deserialize_b64 function"""
@pytest.mark.parametrize("to_str", [True, False])
def test_deserialize_string_and_bytes_input(self, to_str):
"""Test deserialization with both string and bytes input"""
test_data = {"key": "value"}
serialized = serialize_b64(test_data, to_str=to_str)
result = deserialize_b64(serialized)
assert result == test_data
@patch('api.utils.configs.get_base_config')
def test_deserialize_with_safe_module_disabled(self, mock_config):
"""Test deserialization with safe module checking disabled"""
mock_config.return_value = False
test_data = {"test": "data"}
serialized = serialize_b64(test_data, to_str=True)
result = deserialize_b64(serialized)
assert result == test_data
mock_config.assert_called_once_with('use_deserialize_safe_module', False)
@patch('api.utils.configs.get_base_config')
def test_deserialize_with_safe_module_enabled(self, mock_config):
"""Test deserialization with safe module checking enabled"""
mock_config.return_value = True
# Simple data that doesn't require unsafe modules
test_data = {"test": "data", "number": 42}
serialized = serialize_b64(test_data, to_str=True)
result = deserialize_b64(serialized)
assert result == test_data
@pytest.mark.parametrize("test_data", [
{"key": "value"},
{
"string": "test",
"number": 123,
"list": [1, 2, 3],
"nested": {"key": "value"}
},
[1, 2, 3, 4, 5],
"simple string",
42,
3.14,
None,
{"nested": {"deep": {"structure": "value"}}},
])
def test_roundtrip_various_data_types(self, test_data):
"""Test roundtrip serialization and deserialization for various data types"""
serialized = serialize_b64(test_data)
deserialized = deserialize_b64(serialized)
assert deserialized == test_data
class TestRestrictedUnpickler:
"""Test cases for RestrictedUnpickler class"""
@pytest.mark.parametrize("test_data", [
{"test": "data"},
[1, 2, 3, "test", {"key": "value"}],
{"nested": {"deep": "structure"}},
[1, 2, 3],
"simple string",
])
def test_restricted_loads_with_safe_data(self, test_data):
"""Test restricted_loads with various safe data types"""
pickled = pickle.dumps(test_data)
result = restricted_loads(pickled)
assert result == test_data
@patch('api.utils.configs.get_base_config')
def test_blocks_unsafe_modules(self, mock_config):
"""Test that unsafe modules are blocked"""
mock_config.return_value = True
# Try to pickle something from an unsafe module
# We'll simulate this by creating a pickle that references os module
class UnsafeClass:
__module__ = 'os.path'
def __reduce__(self):
return (eval, ("1+1",))
try:
unsafe_obj = UnsafeClass()
pickled = pickle.dumps(unsafe_obj)
with pytest.raises(pickle.UnpicklingError):
restricted_loads(pickled)
except:
# If we can't create the unsafe pickle, skip this test
pytest.skip("Unable to create unsafe pickle for testing")
def test_safe_module_set_contains_expected_modules(self):
"""Test that safe_module set contains expected modules"""
assert 'numpy' in safe_module
assert 'rag_flow' in safe_module
class TestIntegrationScenarios:
"""Integration tests for serialization/deserialization workflows"""
@pytest.mark.parametrize("to_str,original_data", [
(True, {
"user": "test_user",
"settings": {
"theme": "dark",
"notifications": True
},
"items": [1, 2, 3, 4, 5]
}),
(False, {"test": "data", "number": 42}),
(True, {}),
(True, {
f"key_{i}": {
"value": i,
"list": list(range(10)),
"nested": {"deep": f"value_{i}"}
}
for i in range(100)
}),
])
def test_serialize_deserialize_workflow(self, to_str, original_data):
"""Test complete workflow of serialize and deserialize with various data"""
# Serialize
serialized = serialize_b64(original_data, to_str=to_str)
if to_str:
assert isinstance(serialized, str)
else:
assert isinstance(serialized, bytes)
# Deserialize back
deserialized = deserialize_b64(serialized)
assert deserialized == original_data
@patch('api.utils.configs.get_base_config')
def test_safe_deserialization_workflow(self, mock_config):
"""Test safe deserialization workflow"""
mock_config.return_value = True
test_data = {"safe": "data"}
serialized = serialize_b64(test_data, to_str=True)
result = deserialize_b64(serialized)
assert result == test_data
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View file

@ -0,0 +1,452 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for api.utils.json_encode module.
"""
import pytest
import json
import datetime
from enum import Enum, IntEnum
from api.utils.json_encode import (
BaseType,
CustomJSONEncoder,
json_dumps,
json_loads
)
class TestBaseTypeToDict:
"""Test cases for BaseType.to_dict method"""
def test_simple_object_to_dict(self):
"""Test converting simple object to dictionary"""
class SimpleType(BaseType):
def __init__(self):
self.name = "test"
self.value = 42
obj = SimpleType()
result = obj.to_dict()
assert isinstance(result, dict)
assert result == {"name": "test", "value": 42}
def test_private_attributes_excluded(self):
"""Test that private attributes (starting with _) are stripped"""
class PrivateType(BaseType):
def __init__(self):
self._private = "hidden"
self._another = "also hidden"
self.public = "visible"
obj = PrivateType()
result = obj.to_dict()
# Private attributes should have _ stripped in keys
assert "private" in result
assert "another" in result
assert "public" in result
assert result["private"] == "hidden"
assert result["another"] == "also hidden"
assert result["public"] == "visible"
def test_empty_object(self):
"""Test converting empty object to dictionary"""
class EmptyType(BaseType):
pass
obj = EmptyType()
result = obj.to_dict()
assert isinstance(result, dict)
assert len(result) == 0
def test_nested_object_to_dict(self):
"""Test converting object with nested values"""
class NestedType(BaseType):
def __init__(self):
self.data = {"key": "value"}
self.items = [1, 2, 3]
self.text = "test"
obj = NestedType()
result = obj.to_dict()
assert result["data"] == {"key": "value"}
assert result["items"] == [1, 2, 3]
assert result["text"] == "test"
class TestBaseTypeToDictWithType:
"""Test cases for BaseType.to_dict_with_type method"""
def test_includes_type_information(self):
"""Test that type information is included"""
class TypedObject(BaseType):
def __init__(self):
self.value = "test"
obj = TypedObject()
result = obj.to_dict_with_type()
assert "type" in result
assert "data" in result
assert result["type"] == "TypedObject"
def test_includes_module_information(self):
"""Test that module information is included"""
class ModuleObject(BaseType):
def __init__(self):
self.value = "test"
obj = ModuleObject()
result = obj.to_dict_with_type()
assert "module" in result
assert result["module"] is not None
def test_nested_objects_with_types(self):
"""Test nested BaseType objects include type info"""
class InnerType(BaseType):
def __init__(self):
self.inner_value = "inner"
class OuterType(BaseType):
def __init__(self):
self.nested = InnerType()
self.value = "outer"
obj = OuterType()
result = obj.to_dict_with_type()
assert result["type"] == "OuterType"
assert "data" in result
assert "nested" in result["data"]
assert result["data"]["nested"]["type"] == "InnerType"
def test_list_handling(self):
"""Test handling of lists in to_dict_with_type"""
class ListType(BaseType):
def __init__(self):
self.items = [1, 2, 3]
obj = ListType()
result = obj.to_dict_with_type()
assert "items" in result["data"]
items_data = result["data"]["items"]
assert items_data["type"] == "list"
assert isinstance(items_data["data"], list)
def test_dict_handling(self):
"""Test handling of dictionaries in to_dict_with_type"""
class DictType(BaseType):
def __init__(self):
self.config = {"key": "value"}
obj = DictType()
result = obj.to_dict_with_type()
assert "config" in result["data"]
config_data = result["data"]["config"]
assert config_data["type"] == "dict"
class TestCustomJSONEncoder:
"""Test cases for CustomJSONEncoder class"""
def test_encode_datetime(self):
"""Test encoding of datetime objects"""
dt = datetime.datetime(2025, 12, 3, 14, 30, 45)
result = json.dumps(dt, cls=CustomJSONEncoder)
assert result == '"2025-12-03 14:30:45"'
def test_encode_date(self):
"""Test encoding of date objects"""
d = datetime.date(2025, 12, 3)
result = json.dumps(d, cls=CustomJSONEncoder)
assert result == '"2025-12-03"'
def test_encode_timedelta(self):
"""Test encoding of timedelta objects"""
td = datetime.timedelta(days=1, hours=2, minutes=30)
result = json.dumps(td, cls=CustomJSONEncoder)
assert isinstance(json.loads(result), str)
assert "1 day" in json.loads(result)
def test_encode_enum(self):
"""Test encoding of Enum objects"""
class Color(Enum):
RED = "red"
GREEN = "green"
BLUE = "blue"
result = json.dumps(Color.RED, cls=CustomJSONEncoder)
assert result == '"red"'
def test_encode_int_enum(self):
"""Test encoding of IntEnum objects"""
class Priority(IntEnum):
LOW = 1
MEDIUM = 2
HIGH = 3
result = json.dumps(Priority.HIGH, cls=CustomJSONEncoder)
assert result == '3'
def test_encode_set(self):
"""Test encoding of set objects"""
test_set = {1, 2, 3, 4, 5}
result = json.dumps(test_set, cls=CustomJSONEncoder)
# Set should be converted to list
decoded = json.loads(result)
assert isinstance(decoded, list)
assert set(decoded) == test_set
def test_encode_basetype_object(self):
"""Test encoding of BaseType objects"""
class TestType(BaseType):
def __init__(self):
self.name = "test"
self.value = 42
obj = TestType()
result = json.dumps(obj, cls=CustomJSONEncoder)
decoded = json.loads(result)
assert decoded["name"] == "test"
assert decoded["value"] == 42
def test_encode_basetype_with_type_flag(self):
"""Test encoding BaseType with with_type flag"""
class TestType(BaseType):
def __init__(self):
self.value = "test"
obj = TestType()
encoder = CustomJSONEncoder(with_type=True)
result = json.dumps(obj, cls=CustomJSONEncoder, with_type=True)
decoded = json.loads(result)
assert "type" in decoded
assert "data" in decoded
def test_encode_type_object(self):
"""Test encoding of type objects"""
result = json.dumps(str, cls=CustomJSONEncoder)
assert result == '"str"'
def test_encode_nested_structures(self):
"""Test encoding of nested structures with various types"""
class TestType(BaseType):
def __init__(self):
self.name = "test"
data = {
"datetime": datetime.datetime(2025, 12, 3, 12, 0, 0),
"date": datetime.date(2025, 12, 3),
"set": {1, 2, 3},
"object": TestType(),
"list": [1, 2, 3]
}
result = json.dumps(data, cls=CustomJSONEncoder)
decoded = json.loads(result)
assert decoded["datetime"] == "2025-12-03 12:00:00"
assert decoded["date"] == "2025-12-03"
assert isinstance(decoded["set"], list)
assert decoded["object"]["name"] == "test"
class TestJsonDumps:
"""Test cases for json_dumps function"""
def test_json_dumps_basic(self):
"""Test basic json_dumps functionality"""
data = {"key": "value", "number": 42}
result = json_dumps(data)
assert isinstance(result, str)
assert json.loads(result) == data
def test_json_dumps_with_byte_false(self):
"""Test json_dumps with byte=False returns string"""
data = {"test": "data"}
result = json_dumps(data, byte=False)
assert isinstance(result, str)
def test_json_dumps_with_byte_true(self):
"""Test json_dumps with byte=True returns bytes"""
data = {"test": "data"}
result = json_dumps(data, byte=True)
assert isinstance(result, bytes)
def test_json_dumps_with_indent(self):
"""Test json_dumps with indentation"""
data = {"key": "value"}
result = json_dumps(data, indent=2)
assert isinstance(result, str)
assert "\n" in result # Indented JSON has newlines
def test_json_dumps_with_type_false(self):
"""Test json_dumps with with_type=False"""
class TestType(BaseType):
def __init__(self):
self.value = "test"
obj = TestType()
result = json_dumps(obj, with_type=False)
decoded = json.loads(result)
assert "type" not in decoded
assert decoded["value"] == "test"
def test_json_dumps_with_type_true(self):
"""Test json_dumps with with_type=True"""
class TestType(BaseType):
def __init__(self):
self.value = "test"
obj = TestType()
result = json_dumps(obj, with_type=True)
decoded = json.loads(result)
assert "type" in decoded
assert "data" in decoded
def test_json_dumps_datetime(self):
"""Test json_dumps with datetime objects"""
data = {
"timestamp": datetime.datetime(2025, 12, 3, 15, 30, 0)
}
result = json_dumps(data)
decoded = json.loads(result)
assert decoded["timestamp"] == "2025-12-03 15:30:00"
class TestJsonLoads:
"""Test cases for json_loads function"""
def test_json_loads_string_input(self):
"""Test json_loads with string input"""
json_string = '{"key": "value", "number": 42}'
result = json_loads(json_string)
assert isinstance(result, dict)
assert result["key"] == "value"
assert result["number"] == 42
def test_json_loads_bytes_input(self):
"""Test json_loads with bytes input"""
json_bytes = b'{"key": "value"}'
result = json_loads(json_bytes)
assert isinstance(result, dict)
assert result["key"] == "value"
def test_json_loads_with_object_hook(self):
"""Test json_loads with object_hook parameter"""
def custom_hook(obj):
if "special" in obj:
obj["processed"] = True
return obj
json_string = '{"special": "value"}'
result = json_loads(json_string, object_hook=custom_hook)
assert result["processed"] is True
def test_json_loads_empty_object(self):
"""Test json_loads with empty object"""
result = json_loads('{}')
assert isinstance(result, dict)
assert len(result) == 0
def test_json_loads_array(self):
"""Test json_loads with array"""
result = json_loads('[1, 2, 3, 4, 5]')
assert isinstance(result, list)
assert result == [1, 2, 3, 4, 5]
class TestRoundtripConversion:
"""Test roundtrip conversions between dumps and loads"""
def test_roundtrip_dict(self):
"""Test roundtrip conversion of dictionary"""
original = {"key": "value", "number": 42, "list": [1, 2, 3]}
dumped = json_dumps(original)
loaded = json_loads(dumped)
assert loaded == original
def test_roundtrip_with_bytes(self):
"""Test roundtrip with byte conversion"""
original = {"test": "data"}
dumped = json_dumps(original, byte=True)
loaded = json_loads(dumped)
assert loaded == original
def test_roundtrip_basetype(self):
"""Test roundtrip with BaseType object"""
class TestType(BaseType):
def __init__(self):
self.name = "test"
self.value = 42
obj = TestType()
dumped = json_dumps(obj)
loaded = json_loads(dumped)
assert loaded["name"] == "test"
assert loaded["value"] == 42
@pytest.mark.parametrize("test_data", [
{"simple": "dict"},
[1, 2, 3, 4, 5],
{"nested": {"deep": {"structure": "value"}}},
{"mixed": [1, "two", 3.0, {"four": 4}]},
])
def test_roundtrip_various_structures(self, test_data):
"""Test roundtrip for various data structures"""
dumped = json_dumps(test_data)
loaded = json_loads(dumped)
assert loaded == test_data
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View file

@ -0,0 +1,788 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for api.utils.validation_utils module.
"""
import pytest
from unittest.mock import Mock, AsyncMock
from uuid import UUID, uuid1
from pydantic import BaseModel, Field, ValidationError
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
from api.utils.validation_utils import (
validate_and_parse_json_request,
validate_and_parse_request_args,
format_validation_error_message,
normalize_str,
validate_uuid1_hex,
CreateDatasetReq,
UpdateDatasetReq,
DeleteDatasetReq,
ListDatasetReq,
ParserConfig,
RaptorConfig,
GraphragConfig,
)
class TestNormalizeStr:
"""Test cases for normalize_str function"""
def test_normalize_string_with_whitespace(self):
"""Test normalization of string with leading/trailing whitespace"""
result = normalize_str(" Admin ")
assert result == "admin"
def test_normalize_string_uppercase(self):
"""Test normalization converts to lowercase"""
result = normalize_str("UPPERCASE")
assert result == "uppercase"
def test_normalize_mixed_case(self):
"""Test normalization of mixed case string"""
result = normalize_str(" MiXeD CaSe ")
assert result == "mixed case"
def test_normalize_empty_string(self):
"""Test normalization of empty string"""
result = normalize_str("")
assert result == ""
def test_normalize_whitespace_only(self):
"""Test normalization of whitespace-only string"""
result = normalize_str(" ")
assert result == ""
def test_preserve_non_string_integer(self):
"""Test that integers are preserved"""
result = normalize_str(42)
assert result == 42
def test_preserve_non_string_none(self):
"""Test that None is preserved"""
result = normalize_str(None)
assert result is None
def test_preserve_non_string_list(self):
"""Test that lists are preserved"""
input_list = ["User", "Admin"]
result = normalize_str(input_list)
assert result == input_list
def test_preserve_non_string_dict(self):
"""Test that dicts are preserved"""
input_dict = {"role": "Admin"}
result = normalize_str(input_dict)
assert result == input_dict
@pytest.mark.parametrize("input_val,expected", [
("ReadOnly", "readonly"),
(" ADMIN ", "admin"),
("User", "user"),
("", ""),
(123, 123),
(False, False),
(0, 0),
])
def test_various_inputs(self, input_val, expected):
"""Test various input types"""
result = normalize_str(input_val)
assert result == expected
class TestValidateUuid1Hex:
"""Test cases for validate_uuid1_hex function"""
def test_valid_uuid1_string(self):
"""Test validation of valid UUID1 string"""
uuid1_obj = uuid1()
uuid1_str = str(uuid1_obj)
result = validate_uuid1_hex(uuid1_str)
assert isinstance(result, str)
assert len(result) == 32
assert result == uuid1_obj.hex
def test_valid_uuid1_object(self):
"""Test validation of valid UUID1 object"""
uuid1_obj = uuid1()
result = validate_uuid1_hex(uuid1_obj)
assert isinstance(result, str)
assert result == uuid1_obj.hex
def test_uuid1_hex_no_hyphens(self):
"""Test that result has no hyphens"""
uuid1_obj = uuid1()
result = validate_uuid1_hex(uuid1_obj)
assert "-" not in result
def test_invalid_uuid_string_raises_error(self):
"""Test that invalid UUID string raises error"""
from pydantic_core import PydanticCustomError
with pytest.raises(PydanticCustomError) as exc_info:
validate_uuid1_hex("not-a-uuid")
assert exc_info.value.type == "invalid_UUID1_format"
def test_non_uuid1_version_raises_error(self):
"""Test that non-UUID1 version raises error"""
from pydantic_core import PydanticCustomError
from uuid import uuid4
uuid4_obj = uuid4()
with pytest.raises(PydanticCustomError) as exc_info:
validate_uuid1_hex(uuid4_obj)
assert exc_info.value.type == "invalid_UUID1_format"
def test_integer_input_raises_error(self):
"""Test that integer input raises error"""
from pydantic_core import PydanticCustomError
with pytest.raises(PydanticCustomError):
validate_uuid1_hex(12345)
def test_none_input_raises_error(self):
"""Test that None input raises error"""
from pydantic_core import PydanticCustomError
with pytest.raises(PydanticCustomError):
validate_uuid1_hex(None)
class TestFormatValidationErrorMessage:
"""Test cases for format_validation_error_message function"""
def test_single_validation_error(self):
"""Test formatting of single validation error"""
class TestModel(BaseModel):
name: str
try:
TestModel(name=123)
except ValidationError as e:
result = format_validation_error_message(e)
assert "Field: <name>" in result
assert "Message:" in result
assert "Value: <123>" in result
def test_multiple_validation_errors(self):
"""Test formatting of multiple validation errors"""
class TestModel(BaseModel):
name: str
age: int
try:
TestModel(name=123, age="not_an_int")
except ValidationError as e:
result = format_validation_error_message(e)
assert "Field: <name>" in result
assert "Field: <age>" in result
assert "\n" in result # Multiple errors separated by newlines
def test_long_value_truncation(self):
"""Test that long values are truncated"""
class TestModel(BaseModel):
text: str
long_value = "x" * 200
try:
TestModel(text=123) # Wrong type to trigger error
except ValidationError as e:
# Manually create error with long value
pass
# Create a model with max_length constraint
class TestModel2(BaseModel):
text: str = Field(max_length=10)
try:
TestModel2(text="x" * 200)
except ValidationError as e:
result = format_validation_error_message(e)
# Check that value is truncated
assert "..." in result or len(result) < 500
def test_nested_field_error(self):
"""Test formatting of nested field errors"""
class NestedModel(BaseModel):
value: int
class ParentModel(BaseModel):
nested: NestedModel
try:
ParentModel(nested={"value": "not_an_int"})
except ValidationError as e:
result = format_validation_error_message(e)
assert "nested.value" in result
class TestValidateAndParseRequestArgs:
"""Test cases for validate_and_parse_request_args function"""
def test_valid_request_args(self):
"""Test validation of valid request arguments"""
class TestValidator(BaseModel):
param1: str
param2: int = 10
mock_request = Mock()
mock_request.args.to_dict.return_value = {"param1": "value", "param2": "20"}
result, error = validate_and_parse_request_args(mock_request, TestValidator)
assert error is None
assert result is not None
assert result["param1"] == "value"
assert result["param2"] == 20
def test_missing_required_field(self):
"""Test validation with missing required field"""
class TestValidator(BaseModel):
required_field: str
mock_request = Mock()
mock_request.args.to_dict.return_value = {}
result, error = validate_and_parse_request_args(mock_request, TestValidator)
assert result is None
assert error is not None
assert "required_field" in error
def test_with_extras_parameter(self):
"""Test validation with extras parameter"""
class TestValidator(BaseModel):
param1: str
internal_id: int
mock_request = Mock()
mock_request.args.to_dict.return_value = {"param1": "value"}
result, error = validate_and_parse_request_args(
mock_request,
TestValidator,
extras={"internal_id": 123}
)
assert error is None
assert result is not None
assert result["param1"] == "value"
assert "internal_id" not in result # Extras should be removed
def test_type_conversion(self):
"""Test that Pydantic performs type conversion"""
class TestValidator(BaseModel):
number: int
mock_request = Mock()
mock_request.args.to_dict.return_value = {"number": "42"}
result, error = validate_and_parse_request_args(mock_request, TestValidator)
assert error is None
assert result["number"] == 42
assert isinstance(result["number"], int)
class TestValidateAndParseJsonRequest:
"""Test cases for validate_and_parse_json_request function"""
@pytest.mark.anyio
async def test_valid_json_request(self):
"""Test validation of valid JSON request"""
class TestValidator(BaseModel):
name: str
value: int
mock_request = AsyncMock()
mock_request.get_json = AsyncMock(return_value={"name": "test", "value": 42})
result, error = await validate_and_parse_json_request(mock_request, TestValidator)
assert error is None
assert result is not None
assert result["name"] == "test"
assert result["value"] == 42
@pytest.mark.anyio
async def test_unsupported_content_type(self):
"""Test handling of unsupported content type"""
class TestValidator(BaseModel):
name: str
mock_request = AsyncMock()
mock_request.get_json = AsyncMock(side_effect=UnsupportedMediaType())
mock_request.content_type = "text/xml"
result, error = await validate_and_parse_json_request(mock_request, TestValidator)
assert result is None
assert error is not None
assert "Unsupported content type" in error
assert "text/xml" in error
@pytest.mark.anyio
async def test_malformed_json(self):
"""Test handling of malformed JSON"""
class TestValidator(BaseModel):
name: str
mock_request = AsyncMock()
mock_request.get_json = AsyncMock(side_effect=BadRequest())
result, error = await validate_and_parse_json_request(mock_request, TestValidator)
assert result is None
assert error is not None
assert "Malformed JSON syntax" in error
@pytest.mark.anyio
async def test_invalid_payload_type(self):
"""Test handling of non-dict payload"""
class TestValidator(BaseModel):
name: str
mock_request = AsyncMock()
mock_request.get_json = AsyncMock(return_value=["not", "a", "dict"])
result, error = await validate_and_parse_json_request(mock_request, TestValidator)
assert result is None
assert error is not None
assert "Invalid request payload" in error
assert "list" in error
@pytest.mark.anyio
async def test_validation_error(self):
"""Test handling of Pydantic validation errors"""
class TestValidator(BaseModel):
name: str
age: int
mock_request = AsyncMock()
mock_request.get_json = AsyncMock(return_value={"name": 123, "age": "not_int"})
result, error = await validate_and_parse_json_request(mock_request, TestValidator)
assert result is None
assert error is not None
assert "Field:" in error
@pytest.mark.anyio
async def test_with_extras_parameter(self):
"""Test validation with extras parameter"""
class TestValidator(BaseModel):
name: str
user_id: str
mock_request = AsyncMock()
mock_request.get_json = AsyncMock(return_value={"name": "test"})
result, error = await validate_and_parse_json_request(
mock_request,
TestValidator,
extras={"user_id": "user_123"}
)
assert error is None
assert result is not None
assert result["name"] == "test"
assert "user_id" not in result # Extras should be removed
@pytest.mark.anyio
async def test_exclude_unset_parameter(self):
"""Test exclude_unset parameter"""
class TestValidator(BaseModel):
name: str
optional: str = "default"
mock_request = AsyncMock()
mock_request.get_json = AsyncMock(return_value={"name": "test"})
result, error = await validate_and_parse_json_request(
mock_request,
TestValidator,
exclude_unset=True
)
assert error is None
assert result is not None
assert "name" in result
assert "optional" not in result # Not set, should be excluded
class TestCreateDatasetReq:
"""Test cases for CreateDatasetReq validation"""
def test_valid_dataset_creation(self):
"""Test valid dataset creation request"""
data = {
"name": "Test Dataset",
"embedding_model": "text-embedding-3-large@openai"
}
dataset = CreateDatasetReq(**data)
assert dataset.name == "Test Dataset"
assert dataset.embedding_model == "text-embedding-3-large@openai"
def test_name_whitespace_stripping(self):
"""Test that name whitespace is stripped"""
data = {
"name": " Test Dataset "
}
dataset = CreateDatasetReq(**data)
assert dataset.name == "Test Dataset"
def test_empty_name_raises_error(self):
"""Test that empty name raises error"""
with pytest.raises(ValidationError):
CreateDatasetReq(name="")
def test_invalid_embedding_model_format(self):
"""Test that invalid embedding model format raises error"""
with pytest.raises(ValidationError) as exc_info:
CreateDatasetReq(
name="Test",
embedding_model="invalid_model"
)
assert "format_invalid" in str(exc_info.value)
def test_embedding_model_without_provider(self):
"""Test embedding model without provider raises error"""
with pytest.raises(ValidationError):
CreateDatasetReq(
name="Test",
embedding_model="model_name@"
)
def test_valid_avatar_base64(self):
"""Test valid base64 avatar"""
data = {
"name": "Test",
"avatar": ""
}
dataset = CreateDatasetReq(**data)
assert dataset.avatar.startswith(""
)
def test_avatar_missing_data_prefix(self):
"""Test avatar missing data: prefix raises error"""
with pytest.raises(ValidationError):
CreateDatasetReq(
name="Test",
avatar="image/png;base64,abc123"
)
def test_default_chunk_method(self):
"""Test default chunk_method is set to 'naive'"""
dataset = CreateDatasetReq(name="Test")
assert dataset.chunk_method == "naive"
def test_valid_chunk_methods(self):
"""Test various valid chunk methods"""
valid_methods = ["naive", "book", "email", "laws", "manual", "one",
"paper", "picture", "presentation", "qa", "table", "tag"]
for method in valid_methods:
dataset = CreateDatasetReq(name="Test", chunk_method=method)
assert dataset.chunk_method == method
def test_invalid_chunk_method(self):
"""Test invalid chunk method raises error"""
with pytest.raises(ValidationError):
CreateDatasetReq(name="Test", chunk_method="invalid_method")
def test_pipeline_id_validation(self):
"""Test pipeline_id validation"""
# Valid 32-char hex
dataset = CreateDatasetReq(
name="Test",
parse_type=1,
pipeline_id="a" * 32
)
assert dataset.pipeline_id == "a" * 32
def test_pipeline_id_wrong_length(self):
"""Test pipeline_id with wrong length raises error"""
with pytest.raises(ValidationError):
CreateDatasetReq(
name="Test",
parse_type=1,
pipeline_id="abc" # Too short
)
def test_pipeline_id_non_hex(self):
"""Test pipeline_id with non-hex characters raises error"""
with pytest.raises(ValidationError):
CreateDatasetReq(
name="Test",
parse_type=1,
pipeline_id="g" * 32 # 'g' is not hex
)
class TestUpdateDatasetReq:
"""Test cases for UpdateDatasetReq validation"""
def test_valid_update_request(self):
"""Test valid dataset update request"""
uuid1_obj = uuid1()
data = {
"dataset_id": str(uuid1_obj),
"name": "Updated Dataset"
}
dataset = UpdateDatasetReq(**data)
assert dataset.dataset_id == uuid1_obj.hex
assert dataset.name == "Updated Dataset"
def test_dataset_id_uuid1_validation(self):
"""Test that dataset_id must be UUID1"""
from uuid import uuid4
with pytest.raises(ValidationError):
UpdateDatasetReq(
dataset_id=str(uuid4()),
name="Test"
)
def test_pagerank_validation(self):
"""Test pagerank field validation"""
uuid1_obj = uuid1()
dataset = UpdateDatasetReq(
dataset_id=str(uuid1_obj),
name="Test",
pagerank=50
)
assert dataset.pagerank == 50
def test_pagerank_out_of_range(self):
"""Test pagerank out of range raises error"""
uuid1_obj = uuid1()
with pytest.raises(ValidationError):
UpdateDatasetReq(
dataset_id=str(uuid1_obj),
name="Test",
pagerank=101 # Max is 100
)
class TestDeleteDatasetReq:
"""Test cases for DeleteDatasetReq validation"""
def test_valid_delete_request(self):
"""Test valid delete request"""
uuid1_obj1 = uuid1()
uuid1_obj2 = uuid1()
req = DeleteDatasetReq(ids=[str(uuid1_obj1), str(uuid1_obj2)])
assert len(req.ids) == 2
assert uuid1_obj1.hex in req.ids
assert uuid1_obj2.hex in req.ids
def test_duplicate_ids_raises_error(self):
"""Test that duplicate IDs raise error"""
uuid1_obj = uuid1()
with pytest.raises(ValidationError) as exc_info:
DeleteDatasetReq(ids=[str(uuid1_obj), str(uuid1_obj)])
assert "duplicate" in str(exc_info.value).lower()
def test_empty_ids_list(self):
"""Test empty IDs list"""
req = DeleteDatasetReq(ids=[])
assert req.ids == []
def test_none_ids(self):
"""Test None IDs"""
req = DeleteDatasetReq(ids=None)
assert req.ids is None
class TestListDatasetReq:
"""Test cases for ListDatasetReq validation"""
def test_default_values(self):
"""Test default values for list request"""
req = ListDatasetReq()
assert req.page == 1
assert req.page_size == 30
assert req.orderby == "create_time"
assert req.desc is True
def test_custom_pagination(self):
"""Test custom pagination values"""
req = ListDatasetReq(page=2, page_size=50)
assert req.page == 2
assert req.page_size == 50
def test_page_minimum_value(self):
"""Test page minimum value validation"""
with pytest.raises(ValidationError):
ListDatasetReq(page=0)
def test_valid_orderby_values(self):
"""Test valid orderby values"""
req1 = ListDatasetReq(orderby="create_time")
req2 = ListDatasetReq(orderby="update_time")
assert req1.orderby == "create_time"
assert req2.orderby == "update_time"
def test_invalid_orderby_value(self):
"""Test invalid orderby value raises error"""
with pytest.raises(ValidationError):
ListDatasetReq(orderby="invalid_field")
class TestParserConfig:
"""Test cases for ParserConfig validation"""
def test_default_parser_config(self):
"""Test default parser configuration"""
config = ParserConfig()
assert config.chunk_token_num == 512
assert config.auto_keywords == 0
assert config.auto_questions == 0
def test_custom_parser_config(self):
"""Test custom parser configuration"""
config = ParserConfig(
chunk_token_num=1024,
auto_keywords=5,
auto_questions=3
)
assert config.chunk_token_num == 1024
assert config.auto_keywords == 5
assert config.auto_questions == 3
def test_chunk_token_num_range(self):
"""Test chunk_token_num range validation"""
with pytest.raises(ValidationError):
ParserConfig(chunk_token_num=3000) # Max is 2048
def test_raptor_config_integration(self):
"""Test raptor config integration"""
config = ParserConfig(
raptor=RaptorConfig(use_raptor=True, max_token=512)
)
assert config.raptor.use_raptor is True
assert config.raptor.max_token == 512
class TestRaptorConfig:
"""Test cases for RaptorConfig validation"""
def test_default_raptor_config(self):
"""Test default raptor configuration"""
config = RaptorConfig()
assert config.use_raptor is False
assert config.max_token == 256
assert config.threshold == 0.1
def test_custom_raptor_config(self):
"""Test custom raptor configuration"""
config = RaptorConfig(
use_raptor=True,
max_token=512,
threshold=0.2
)
assert config.use_raptor is True
assert config.max_token == 512
assert config.threshold == 0.2
def test_threshold_range(self):
"""Test threshold range validation"""
with pytest.raises(ValidationError):
RaptorConfig(threshold=1.5) # Max is 1.0
class TestGraphragConfig:
"""Test cases for GraphragConfig validation"""
def test_default_graphrag_config(self):
"""Test default graphrag configuration"""
config = GraphragConfig()
assert config.use_graphrag is False
assert config.method == "light"
assert "organization" in config.entity_types
def test_custom_graphrag_config(self):
"""Test custom graphrag configuration"""
config = GraphragConfig(
use_graphrag=True,
method="general",
entity_types=["person", "location"]
)
assert config.use_graphrag is True
assert config.method == "general"
assert config.entity_types == ["person", "location"]
def test_invalid_method(self):
"""Test invalid method raises error"""
with pytest.raises(ValidationError):
GraphragConfig(method="invalid")
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View file

@ -0,0 +1,23 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
from pathlib import Path
# Add project root to Python path
project_root = Path(__file__).parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))