# ragflow/test/unit_test/api/test_validation_utils.py
# 2025-12-03 12:27:37 +01:00
#
# 788 lines
# 25 KiB
# Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for api.utils.validation_utils module.
"""
import pytest
from unittest.mock import Mock, AsyncMock
from uuid import UUID, uuid1
from pydantic import BaseModel, Field, ValidationError
from werkzeug.exceptions import BadRequest, UnsupportedMediaType
from api.utils.validation_utils import (
validate_and_parse_json_request,
validate_and_parse_request_args,
format_validation_error_message,
normalize_str,
validate_uuid1_hex,
CreateDatasetReq,
UpdateDatasetReq,
DeleteDatasetReq,
ListDatasetReq,
ParserConfig,
RaptorConfig,
GraphragConfig,
)
class TestNormalizeStr:
    """Unit tests for ``normalize_str``.

    String inputs are expected back stripped and lower-cased; every
    non-string input is expected back unchanged.
    """

    def test_normalize_string_with_whitespace(self):
        """Surrounding whitespace is dropped and the text lower-cased."""
        assert normalize_str(" Admin ") == "admin"

    def test_normalize_string_uppercase(self):
        """An all-uppercase string comes back lower-cased."""
        assert normalize_str("UPPERCASE") == "uppercase"

    def test_normalize_mixed_case(self):
        """Mixed case folds to lower case while inner spaces are kept."""
        assert normalize_str(" MiXeD CaSe ") == "mixed case"

    def test_normalize_empty_string(self):
        """The empty string stays empty."""
        assert normalize_str("") == ""

    def test_normalize_whitespace_only(self):
        """A whitespace-only string collapses to the empty string."""
        assert normalize_str(" ") == ""

    def test_preserve_non_string_integer(self):
        """Integers are passed through."""
        assert normalize_str(42) == 42

    def test_preserve_non_string_none(self):
        """None is passed through."""
        assert normalize_str(None) is None

    def test_preserve_non_string_list(self):
        """Lists are passed through."""
        roles = ["User", "Admin"]
        assert normalize_str(roles) == roles

    def test_preserve_non_string_dict(self):
        """Dicts are passed through."""
        mapping = {"role": "Admin"}
        assert normalize_str(mapping) == mapping

    @pytest.mark.parametrize(
        ("input_val", "expected"),
        [
            ("ReadOnly", "readonly"),
            (" ADMIN ", "admin"),
            ("User", "user"),
            ("", ""),
            (123, 123),
            (False, False),
            (0, 0),
        ],
    )
    def test_various_inputs(self, input_val, expected):
        """Table-driven check across string and non-string inputs."""
        assert normalize_str(input_val) == expected
class TestValidateUuid1Hex:
    """Test cases for validate_uuid1_hex function.

    Each rejection test pins the error type ``"invalid_UUID1_format"``, so an
    unrelated ``PydanticCustomError`` cannot make a failure test pass.
    """

    def test_valid_uuid1_string(self):
        """Test validation of valid UUID1 string"""
        uuid1_obj = uuid1()
        uuid1_str = str(uuid1_obj)
        result = validate_uuid1_hex(uuid1_str)
        assert isinstance(result, str)
        assert len(result) == 32
        assert result == uuid1_obj.hex

    def test_valid_uuid1_object(self):
        """Test validation of valid UUID1 object"""
        uuid1_obj = uuid1()
        result = validate_uuid1_hex(uuid1_obj)
        assert isinstance(result, str)
        assert result == uuid1_obj.hex

    def test_uuid1_hex_no_hyphens(self):
        """Test that result has no hyphens"""
        uuid1_obj = uuid1()
        result = validate_uuid1_hex(uuid1_obj)
        assert "-" not in result

    def test_invalid_uuid_string_raises_error(self):
        """Test that invalid UUID string raises error"""
        from pydantic_core import PydanticCustomError

        with pytest.raises(PydanticCustomError) as exc_info:
            validate_uuid1_hex("not-a-uuid")
        assert exc_info.value.type == "invalid_UUID1_format"

    def test_non_uuid1_version_raises_error(self):
        """Test that non-UUID1 version raises error"""
        from pydantic_core import PydanticCustomError
        from uuid import uuid4

        uuid4_obj = uuid4()
        with pytest.raises(PydanticCustomError) as exc_info:
            validate_uuid1_hex(uuid4_obj)
        assert exc_info.value.type == "invalid_UUID1_format"

    def test_integer_input_raises_error(self):
        """Test that integer input raises error"""
        from pydantic_core import PydanticCustomError

        with pytest.raises(PydanticCustomError) as exc_info:
            validate_uuid1_hex(12345)
        # Pin the error type, as the sibling tests do.
        assert exc_info.value.type == "invalid_UUID1_format"

    def test_none_input_raises_error(self):
        """Test that None input raises error"""
        from pydantic_core import PydanticCustomError

        with pytest.raises(PydanticCustomError) as exc_info:
            validate_uuid1_hex(None)
        # Pin the error type, as the sibling tests do.
        assert exc_info.value.type == "invalid_UUID1_format"
class TestFormatValidationErrorMessage:
    """Test cases for format_validation_error_message function.

    Each test obtains a real pydantic ``ValidationError`` through
    ``pytest.raises``. If a model unexpectedly accepts its input, the test
    fails instead of silently skipping its assertions, which is what the
    earlier ``try``/``except`` form allowed.
    """

    def test_single_validation_error(self):
        """Test formatting of single validation error"""

        class _Model(BaseModel):
            name: str

        with pytest.raises(ValidationError) as exc_info:
            _Model(name=123)
        result = format_validation_error_message(exc_info.value)
        assert "Field: <name>" in result
        assert "Message:" in result
        assert "Value: <123>" in result

    def test_multiple_validation_errors(self):
        """Test formatting of multiple validation errors"""

        class _Model(BaseModel):
            name: str
            age: int

        with pytest.raises(ValidationError) as exc_info:
            _Model(name=123, age="not_an_int")
        result = format_validation_error_message(exc_info.value)
        assert "Field: <name>" in result
        assert "Field: <age>" in result
        assert "\n" in result  # Multiple errors separated by newlines

    def test_long_value_truncation(self):
        """Test that long values are truncated"""

        class _Model(BaseModel):
            text: str = Field(max_length=10)

        long_value = "x" * 200
        with pytest.raises(ValidationError) as exc_info:
            _Model(text=long_value)
        result = format_validation_error_message(exc_info.value)
        # The full offending value must not be echoed verbatim; an ellipsis
        # marks the cut. (The old "or len(result) < 500" fallback passed even
        # without truncation, because an untruncated message is well under 500.)
        assert long_value not in result
        assert "..." in result

    def test_nested_field_error(self):
        """Test formatting of nested field errors"""

        class NestedModel(BaseModel):
            value: int

        class ParentModel(BaseModel):
            nested: NestedModel

        with pytest.raises(ValidationError) as exc_info:
            ParentModel(nested={"value": "not_an_int"})
        result = format_validation_error_message(exc_info.value)
        assert "nested.value" in result
class TestValidateAndParseRequestArgs:
    """Tests for ``validate_and_parse_request_args`` driven by a stub request.

    The tests configure only ``request.args.to_dict()``. The private helper
    below builds that stub, so each test focuses on its validator and the
    ``(parsed, error)`` pair it expects back.
    """

    @staticmethod
    def _request_with_args(query):
        """Return a ``Mock`` request whose ``args.to_dict()`` yields *query*."""
        stub = Mock()
        stub.args.to_dict.return_value = query
        return stub

    def test_valid_request_args(self):
        """Well-formed query args validate, and the model coerces them."""

        class TestValidator(BaseModel):
            param1: str
            param2: int = 10

        request = self._request_with_args({"param1": "value", "param2": "20"})
        parsed, err = validate_and_parse_request_args(request, TestValidator)
        assert err is None
        assert parsed is not None
        assert parsed["param1"] == "value"
        assert parsed["param2"] == 20

    def test_missing_required_field(self):
        """A missing required arg gives no result and an error naming the field."""

        class TestValidator(BaseModel):
            required_field: str

        request = self._request_with_args({})
        parsed, err = validate_and_parse_request_args(request, TestValidator)
        assert parsed is None
        assert err is not None
        assert "required_field" in err

    def test_with_extras_parameter(self):
        """Server-side extras take part in validation but are not returned."""

        class TestValidator(BaseModel):
            param1: str
            internal_id: int

        request = self._request_with_args({"param1": "value"})
        parsed, err = validate_and_parse_request_args(
            request,
            TestValidator,
            extras={"internal_id": 123},
        )
        assert err is None
        assert parsed is not None
        assert parsed["param1"] == "value"
        assert "internal_id" not in parsed  # Extras should be removed

    def test_type_conversion(self):
        """A numeric string is converted to ``int`` by Pydantic."""

        class TestValidator(BaseModel):
            number: int

        request = self._request_with_args({"number": "42"})
        parsed, err = validate_and_parse_request_args(request, TestValidator)
        assert err is None
        converted = parsed["number"]
        assert converted == 42
        assert isinstance(converted, int)
class TestValidateAndParseJsonRequest:
    """Tests for the async ``validate_and_parse_json_request`` helper.

    Each request is an ``AsyncMock``. Its ``get_json`` coroutine either
    returns a canned payload or raises a werkzeug HTTP exception.
    """

    @staticmethod
    def _json_request(**get_json_behaviour):
        """Return an ``AsyncMock`` request whose ``get_json`` is configured by the kwargs.

        Pass ``return_value=...`` to supply a payload, or ``side_effect=...``
        to make the call raise.
        """
        stub = AsyncMock()
        stub.get_json = AsyncMock(**get_json_behaviour)
        return stub

    @pytest.mark.anyio
    async def test_valid_json_request(self):
        """A well-formed payload validates and comes back as a dict."""

        class TestValidator(BaseModel):
            name: str
            value: int

        request = self._json_request(return_value={"name": "test", "value": 42})
        parsed, err = await validate_and_parse_json_request(request, TestValidator)
        assert err is None
        assert parsed is not None
        assert parsed["name"] == "test"
        assert parsed["value"] == 42

    @pytest.mark.anyio
    async def test_unsupported_content_type(self):
        """UnsupportedMediaType gives an error that names the content type."""

        class TestValidator(BaseModel):
            name: str

        request = self._json_request(side_effect=UnsupportedMediaType())
        request.content_type = "text/xml"
        parsed, err = await validate_and_parse_json_request(request, TestValidator)
        assert parsed is None
        assert err is not None
        assert "Unsupported content type" in err
        assert "text/xml" in err

    @pytest.mark.anyio
    async def test_malformed_json(self):
        """BadRequest from the JSON parser gives a malformed-syntax error."""

        class TestValidator(BaseModel):
            name: str

        request = self._json_request(side_effect=BadRequest())
        parsed, err = await validate_and_parse_json_request(request, TestValidator)
        assert parsed is None
        assert err is not None
        assert "Malformed JSON syntax" in err

    @pytest.mark.anyio
    async def test_invalid_payload_type(self):
        """A non-dict JSON body is rejected, and the error names its type."""

        class TestValidator(BaseModel):
            name: str

        request = self._json_request(return_value=["not", "a", "dict"])
        parsed, err = await validate_and_parse_json_request(request, TestValidator)
        assert parsed is None
        assert err is not None
        assert "Invalid request payload" in err
        assert "list" in err

    @pytest.mark.anyio
    async def test_validation_error(self):
        """Pydantic failures are reported in the formatted ``Field:`` style."""

        class TestValidator(BaseModel):
            name: str
            age: int

        request = self._json_request(return_value={"name": 123, "age": "not_int"})
        parsed, err = await validate_and_parse_json_request(request, TestValidator)
        assert parsed is None
        assert err is not None
        assert "Field:" in err

    @pytest.mark.anyio
    async def test_with_extras_parameter(self):
        """Server-side extras take part in validation but are not returned."""

        class TestValidator(BaseModel):
            name: str
            user_id: str

        request = self._json_request(return_value={"name": "test"})
        parsed, err = await validate_and_parse_json_request(
            request,
            TestValidator,
            extras={"user_id": "user_123"},
        )
        assert err is None
        assert parsed is not None
        assert parsed["name"] == "test"
        assert "user_id" not in parsed  # Extras should be removed

    @pytest.mark.anyio
    async def test_exclude_unset_parameter(self):
        """With ``exclude_unset=True``, fields left at their default are omitted."""

        class TestValidator(BaseModel):
            name: str
            optional: str = "default"

        request = self._json_request(return_value={"name": "test"})
        parsed, err = await validate_and_parse_json_request(
            request,
            TestValidator,
            exclude_unset=True,
        )
        assert err is None
        assert parsed is not None
        assert "name" in parsed
        assert "optional" not in parsed  # Not set, should be excluded
class TestCreateDatasetReq:
    """Test cases for CreateDatasetReq validation"""

    def test_valid_dataset_creation(self):
        """Test valid dataset creation request"""
        data = {
            "name": "Test Dataset",
            "embedding_model": "text-embedding-3-large@openai",
        }
        dataset = CreateDatasetReq(**data)
        assert dataset.name == "Test Dataset"
        assert dataset.embedding_model == "text-embedding-3-large@openai"

    def test_name_whitespace_stripping(self):
        """Test that name whitespace is stripped"""
        data = {
            "name": " Test Dataset ",
        }
        dataset = CreateDatasetReq(**data)
        assert dataset.name == "Test Dataset"

    def test_empty_name_raises_error(self):
        """Test that empty name raises error"""
        with pytest.raises(ValidationError):
            CreateDatasetReq(name="")

    def test_invalid_embedding_model_format(self):
        """Test that invalid embedding model format raises error"""
        with pytest.raises(ValidationError) as exc_info:
            CreateDatasetReq(
                name="Test",
                embedding_model="invalid_model",
            )
        assert "format_invalid" in str(exc_info.value)

    def test_embedding_model_without_provider(self):
        """Test embedding model without provider raises error"""
        with pytest.raises(ValidationError):
            CreateDatasetReq(
                name="Test",
                embedding_model="model_name@",
            )

    def test_valid_avatar_base64(self):
        """Test valid base64 avatar"""
        # A data URI of the form "data:<mime>;base64,<payload>". The form
        # without the "data:" prefix is rejected; see
        # test_avatar_missing_data_prefix below.
        avatar = (
            "data:image/png;base64,"
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
        )
        dataset = CreateDatasetReq(name="Test", avatar=avatar)
        assert dataset.avatar.startswith("data:image/png;base64,")

    def test_avatar_missing_data_prefix(self):
        """Test avatar missing data: prefix raises error"""
        with pytest.raises(ValidationError):
            CreateDatasetReq(
                name="Test",
                avatar="image/png;base64,abc123",
            )

    def test_default_chunk_method(self):
        """Test default chunk_method is set to 'naive'"""
        dataset = CreateDatasetReq(name="Test")
        assert dataset.chunk_method == "naive"

    # Parametrized instead of looping, so each rejected method is reported
    # as its own failure rather than stopping at the first one.
    @pytest.mark.parametrize(
        "method",
        [
            "naive", "book", "email", "laws", "manual", "one",
            "paper", "picture", "presentation", "qa", "table", "tag",
        ],
    )
    def test_valid_chunk_methods(self, method):
        """Test various valid chunk methods"""
        dataset = CreateDatasetReq(name="Test", chunk_method=method)
        assert dataset.chunk_method == method

    def test_invalid_chunk_method(self):
        """Test invalid chunk method raises error"""
        with pytest.raises(ValidationError):
            CreateDatasetReq(name="Test", chunk_method="invalid_method")

    def test_pipeline_id_validation(self):
        """Test pipeline_id validation"""
        # Valid 32-char hex
        dataset = CreateDatasetReq(
            name="Test",
            parse_type=1,
            pipeline_id="a" * 32,
        )
        assert dataset.pipeline_id == "a" * 32

    def test_pipeline_id_wrong_length(self):
        """Test pipeline_id with wrong length raises error"""
        with pytest.raises(ValidationError):
            CreateDatasetReq(
                name="Test",
                parse_type=1,
                pipeline_id="abc",  # Too short
            )

    def test_pipeline_id_non_hex(self):
        """Test pipeline_id with non-hex characters raises error"""
        with pytest.raises(ValidationError):
            CreateDatasetReq(
                name="Test",
                parse_type=1,
                pipeline_id="g" * 32,  # 'g' is not hex
            )
class TestUpdateDatasetReq:
    """Validation tests for ``UpdateDatasetReq``.

    ``dataset_id`` must be a UUID1. On success the model exposes it as the
    UUID's ``.hex`` string.
    """

    def test_valid_update_request(self):
        """A UUID1 id plus a name validates; the id is returned in hex form."""
        dataset_uuid = uuid1()
        req = UpdateDatasetReq(dataset_id=str(dataset_uuid), name="Updated Dataset")
        assert req.dataset_id == dataset_uuid.hex
        assert req.name == "Updated Dataset"

    def test_dataset_id_uuid1_validation(self):
        """A UUID of another version (here UUID4) is rejected."""
        from uuid import uuid4

        foreign_id = str(uuid4())
        with pytest.raises(ValidationError):
            UpdateDatasetReq(dataset_id=foreign_id, name="Test")

    def test_pagerank_validation(self):
        """An in-range pagerank is stored as given."""
        req = UpdateDatasetReq(dataset_id=str(uuid1()), name="Test", pagerank=50)
        assert req.pagerank == 50

    def test_pagerank_out_of_range(self):
        """A pagerank above the allowed maximum is rejected."""
        dataset_id = str(uuid1())
        with pytest.raises(ValidationError):
            UpdateDatasetReq(dataset_id=dataset_id, name="Test", pagerank=101)  # Max is 100
class TestDeleteDatasetReq:
    """Validation tests for ``DeleteDatasetReq``."""

    def test_valid_delete_request(self):
        """Two distinct UUID1 ids are accepted and exposed in hex form."""
        first, second = uuid1(), uuid1()
        req = DeleteDatasetReq(ids=[str(first), str(second)])
        assert len(req.ids) == 2
        for expected_hex in (first.hex, second.hex):
            assert expected_hex in req.ids

    def test_duplicate_ids_raises_error(self):
        """Passing the same id twice fails with a message mentioning duplicates."""
        repeated = str(uuid1())
        with pytest.raises(ValidationError) as exc_info:
            DeleteDatasetReq(ids=[repeated, repeated])
        assert "duplicate" in str(exc_info.value).lower()

    def test_empty_ids_list(self):
        """An empty list is kept as an empty list."""
        assert DeleteDatasetReq(ids=[]).ids == []

    def test_none_ids(self):
        """``ids=None`` is kept as None."""
        assert DeleteDatasetReq(ids=None).ids is None
class TestListDatasetReq:
    """Validation tests for ``ListDatasetReq``."""

    def test_default_values(self):
        """Defaults: page 1, 30 per page, ordered by create_time descending."""
        defaults = ListDatasetReq()
        assert defaults.page == 1
        assert defaults.page_size == 30
        assert defaults.orderby == "create_time"
        assert defaults.desc is True

    def test_custom_pagination(self):
        """Explicit page and page_size values are kept."""
        paged = ListDatasetReq(page=2, page_size=50)
        assert (paged.page, paged.page_size) == (2, 50)

    def test_page_minimum_value(self):
        """Page numbering starts at 1, so page 0 is rejected."""
        with pytest.raises(ValidationError):
            ListDatasetReq(page=0)

    def test_valid_orderby_values(self):
        """Both supported sort columns are accepted."""
        for column in ("create_time", "update_time"):
            assert ListDatasetReq(orderby=column).orderby == column

    def test_invalid_orderby_value(self):
        """An unknown sort column is rejected."""
        with pytest.raises(ValidationError):
            ListDatasetReq(orderby="invalid_field")
class TestParserConfig:
    """Validation tests for ``ParserConfig``."""

    def test_default_parser_config(self):
        """Defaults: 512 tokens per chunk, zero auto keywords and questions."""
        cfg = ParserConfig()
        assert cfg.chunk_token_num == 512
        assert cfg.auto_keywords == 0
        assert cfg.auto_questions == 0

    def test_custom_parser_config(self):
        """Explicit values are stored as given."""
        overrides = {"chunk_token_num": 1024, "auto_keywords": 5, "auto_questions": 3}
        cfg = ParserConfig(**overrides)
        for field_name, expected in overrides.items():
            assert getattr(cfg, field_name) == expected

    def test_chunk_token_num_range(self):
        """A chunk_token_num above the allowed maximum is rejected."""
        with pytest.raises(ValidationError):
            ParserConfig(chunk_token_num=3000)  # Max is 2048

    def test_raptor_config_integration(self):
        """A nested RaptorConfig is accepted and exposed on the parent."""
        raptor = RaptorConfig(use_raptor=True, max_token=512)
        cfg = ParserConfig(raptor=raptor)
        assert cfg.raptor.use_raptor is True
        assert cfg.raptor.max_token == 512
class TestRaptorConfig:
    """Validation tests for ``RaptorConfig``."""

    def test_default_raptor_config(self):
        """Defaults: RAPTOR disabled, max_token 256, threshold 0.1."""
        cfg = RaptorConfig()
        assert cfg.use_raptor is False
        assert cfg.max_token == 256
        assert cfg.threshold == 0.1

    def test_custom_raptor_config(self):
        """Explicit values are stored as given."""
        cfg = RaptorConfig(use_raptor=True, max_token=512, threshold=0.2)
        assert cfg.use_raptor is True
        assert cfg.max_token == 512
        assert cfg.threshold == 0.2

    def test_threshold_range(self):
        """A threshold above the allowed maximum is rejected."""
        with pytest.raises(ValidationError):
            RaptorConfig(threshold=1.5)  # Max is 1.0
class TestGraphragConfig:
    """Validation tests for ``GraphragConfig``."""

    def test_default_graphrag_config(self):
        """Defaults: GraphRAG disabled, "light" method, organization among entity types."""
        cfg = GraphragConfig()
        assert cfg.use_graphrag is False
        assert cfg.method == "light"
        assert "organization" in cfg.entity_types

    def test_custom_graphrag_config(self):
        """Explicit values, including a custom entity list, are stored as given."""
        entities = ["person", "location"]
        cfg = GraphragConfig(use_graphrag=True, method="general", entity_types=entities)
        assert cfg.use_graphrag is True
        assert cfg.method == "general"
        assert cfg.entity_types == ["person", "location"]

    def test_invalid_method(self):
        """An unknown extraction method is rejected."""
        with pytest.raises(ValidationError):
            GraphragConfig(method="invalid")
if __name__ == "__main__":
    # pytest.main returns the session exit code. Raise it as SystemExit so a
    # direct `python <this file>` run exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))