ragflow/test/unit_test/common/test_connection_utils.py
2025-12-03 18:15:22 +01:00

340 lines
13 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import os
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import trio
import anyio
from common.connection_utils import timeout, construct_response, sync_construct_response
from common.constants import RetCode
class TestTimeoutDecorator:
"""Test cases for the timeout decorator"""
def test_sync_function_success(self):
"""Test timeout decorator with successful sync function"""
@timeout(seconds=5)
def fast_function():
return "success"
result = fast_function()
assert result == "success"
def test_sync_function_with_args(self):
"""Test timeout decorator with sync function that has arguments"""
@timeout(seconds=5)
def add_numbers(a, b):
return a + b
result = add_numbers(3, 4)
assert result == 7
def test_sync_function_raises_exception(self):
"""Test timeout decorator propagates exceptions from sync function"""
@timeout(seconds=5)
def failing_function():
raise ValueError("Test error")
with pytest.raises(ValueError, match="Test error"):
failing_function()
def test_sync_function_timeout_disabled(self):
"""Test sync function when timeout is disabled (no ENABLE_TIMEOUT_ASSERTION)"""
@timeout(seconds=1)
def slow_function():
time.sleep(0.1)
return "completed"
# Without ENABLE_TIMEOUT_ASSERTION, timeout should not trigger
result = slow_function()
assert result == "completed"
def test_sync_function_timeout_enabled(self):
"""Test sync function timeout when ENABLE_TIMEOUT_ASSERTION is set"""
@timeout(seconds=0.1, attempts=1)
def slow_function():
time.sleep(2)
return "should not reach"
with patch.dict(os.environ, {"ENABLE_TIMEOUT_ASSERTION": "1"}):
with pytest.raises(TimeoutError, match="timed out after 0.1 seconds and 1 attempts"):
slow_function()
def test_sync_function_with_string_seconds(self):
"""Test timeout decorator with string seconds parameter"""
@timeout(seconds="5")
def fast_function():
return "success"
result = fast_function()
assert result == "success"
def test_sync_function_multiple_attempts(self):
"""Test sync function with multiple attempts"""
call_count = 0
@timeout(seconds=0.1, attempts=3)
def function_with_attempts():
nonlocal call_count
call_count += 1
time.sleep(0.05)
return f"attempt_{call_count}"
result = function_with_attempts()
assert result == "attempt_1"
@pytest.mark.anyio
async def test_async_function_success(self):
"""Test timeout decorator with successful async function"""
@timeout(seconds=5)
async def fast_async_function():
await anyio.sleep(0.01)
return "async_success"
result = await fast_async_function()
assert result == "async_success"
@pytest.mark.anyio
async def test_async_function_with_args(self):
"""Test timeout decorator with async function that has arguments"""
@timeout(seconds=5)
async def async_multiply(x, y):
await anyio.sleep(0.01)
return x * y
result = await async_multiply(6, 7)
assert result == 42
@pytest.mark.anyio
async def test_async_function_none_timeout(self):
"""Test async function with None timeout (no timeout applied)"""
@timeout(seconds=None)
async def async_function():
await anyio.sleep(0.01)
return "no_timeout"
result = await async_function()
assert result == "no_timeout"
@pytest.mark.anyio
async def test_async_function_timeout_disabled(self):
"""Test async function when timeout is disabled"""
@timeout(seconds=1)
async def slow_async_function():
await anyio.sleep(0.1)
return "completed"
# Without ENABLE_TIMEOUT_ASSERTION, timeout should not trigger
result = await slow_async_function()
assert result == "completed"
@pytest.mark.anyio
async def test_async_function_multiple_attempts(self):
"""Test async function with multiple attempts"""
call_count = 0
@timeout(seconds=0.1, attempts=3)
async def function_with_attempts():
nonlocal call_count
call_count += 1
await anyio.sleep(0.05)
return f"attempt_{call_count}"
result = await function_with_attempts()
assert result == "attempt_1"
class TestConstructResponse:
"""Test cases for construct_response function"""
@pytest.mark.anyio
async def test_construct_response_default(self):
"""Test construct_response with default parameters"""
with patch('common.connection_utils.make_response', new_callable=AsyncMock) as mock_make_response:
with patch('common.connection_utils.jsonify') as mock_jsonify:
mock_response = MagicMock()
mock_response.headers = {}
mock_make_response.return_value = mock_response
mock_jsonify.return_value = {"code": RetCode.SUCCESS, "message": "success"}
response = await construct_response()
mock_jsonify.assert_called_once_with({"code": RetCode.SUCCESS, "message": "success"})
assert response.headers["Access-Control-Allow-Origin"] == "*"
assert response.headers["Access-Control-Allow-Method"] == "*"
assert response.headers["Access-Control-Allow-Headers"] == "*"
assert response.headers["Access-Control-Expose-Headers"] == "Authorization"
@pytest.mark.anyio
async def test_construct_response_with_data(self):
"""Test construct_response with data parameter"""
with patch('common.connection_utils.make_response', new_callable=AsyncMock) as mock_make_response:
with patch('common.connection_utils.jsonify') as mock_jsonify:
mock_response = MagicMock()
mock_response.headers = {}
mock_make_response.return_value = mock_response
test_data = {"key": "value"}
response = await construct_response(data=test_data)
call_args = mock_jsonify.call_args[0][0]
assert call_args["code"] == RetCode.SUCCESS
assert call_args["message"] == "success"
assert call_args["data"] == test_data
@pytest.mark.anyio
async def test_construct_response_with_error_code(self):
"""Test construct_response with error code"""
with patch('common.connection_utils.make_response', new_callable=AsyncMock) as mock_make_response:
with patch('common.connection_utils.jsonify') as mock_jsonify:
mock_response = MagicMock()
mock_response.headers = {}
mock_make_response.return_value = mock_response
response = await construct_response(
code=RetCode.ARGUMENT_ERROR,
message="Invalid argument"
)
call_args = mock_jsonify.call_args[0][0]
assert call_args["code"] == RetCode.ARGUMENT_ERROR
assert call_args["message"] == "Invalid argument"
@pytest.mark.anyio
async def test_construct_response_with_auth(self):
"""Test construct_response with auth token"""
with patch('common.connection_utils.make_response', new_callable=AsyncMock) as mock_make_response:
with patch('common.connection_utils.jsonify') as mock_jsonify:
mock_response = MagicMock()
mock_response.headers = {}
mock_make_response.return_value = mock_response
auth_token = "Bearer test_token"
response = await construct_response(auth=auth_token)
assert response.headers["Authorization"] == auth_token
@pytest.mark.anyio
async def test_construct_response_none_values_excluded(self):
"""Test that None values are excluded from response (except code)"""
with patch('common.connection_utils.make_response', new_callable=AsyncMock) as mock_make_response:
with patch('common.connection_utils.jsonify') as mock_jsonify:
mock_response = MagicMock()
mock_response.headers = {}
mock_make_response.return_value = mock_response
response = await construct_response(
code=RetCode.SUCCESS,
message=None,
data=None
)
call_args = mock_jsonify.call_args[0][0]
assert "code" in call_args
assert "message" not in call_args
assert "data" not in call_args
class TestSyncConstructResponse:
"""Test cases for sync_construct_response function"""
def test_sync_construct_response_default(self):
"""Test sync_construct_response with default parameters"""
with patch('flask.make_response') as mock_make_response:
with patch('flask.jsonify') as mock_jsonify:
mock_response = MagicMock()
mock_response.headers = {}
mock_make_response.return_value = mock_response
mock_jsonify.return_value = {"code": RetCode.SUCCESS, "message": "success"}
response = sync_construct_response()
mock_jsonify.assert_called_once_with({"code": RetCode.SUCCESS, "message": "success"})
assert response.headers["Access-Control-Allow-Origin"] == "*"
assert response.headers["Access-Control-Allow-Method"] == "*"
assert response.headers["Access-Control-Allow-Headers"] == "*"
assert response.headers["Access-Control-Expose-Headers"] == "Authorization"
def test_sync_construct_response_with_data(self):
"""Test sync_construct_response with data parameter"""
with patch('flask.make_response') as mock_make_response:
with patch('flask.jsonify') as mock_jsonify:
mock_response = MagicMock()
mock_response.headers = {}
mock_make_response.return_value = mock_response
test_data = {"result": "test"}
response = sync_construct_response(data=test_data)
call_args = mock_jsonify.call_args[0][0]
assert call_args["code"] == RetCode.SUCCESS
assert call_args["message"] == "success"
assert call_args["data"] == test_data
def test_sync_construct_response_with_error_code(self):
"""Test sync_construct_response with error code"""
with patch('flask.make_response') as mock_make_response:
with patch('flask.jsonify') as mock_jsonify:
mock_response = MagicMock()
mock_response.headers = {}
mock_make_response.return_value = mock_response
response = sync_construct_response(
code=RetCode.SERVER_ERROR,
message="Server error occurred"
)
call_args = mock_jsonify.call_args[0][0]
assert call_args["code"] == RetCode.SERVER_ERROR
assert call_args["message"] == "Server error occurred"
def test_sync_construct_response_with_auth(self):
"""Test sync_construct_response with auth token"""
with patch('flask.make_response') as mock_make_response:
with patch('flask.jsonify') as mock_jsonify:
mock_response = MagicMock()
mock_response.headers = {}
mock_make_response.return_value = mock_response
auth_token = "Bearer sync_token"
response = sync_construct_response(auth=auth_token)
assert response.headers["Authorization"] == auth_token
def test_sync_construct_response_none_values_excluded(self):
"""Test that None values are excluded from response (except code)"""
with patch('flask.make_response') as mock_make_response:
with patch('flask.jsonify') as mock_jsonify:
mock_response = MagicMock()
mock_response.headers = {}
mock_make_response.return_value = mock_response
response = sync_construct_response(
code=RetCode.SUCCESS,
message=None,
data=None
)
call_args = mock_jsonify.call_args[0][0]
assert "code" in call_args
assert "message" not in call_args
assert "data" not in call_args