- Remove unused MagicMock import from test_dialog_service.py - Remove unused MagicMock import from test_conversation_service.py - Remove unused MagicMock import from test_canvas_service.py Fixes CI linting errors (F401: imported but unused)
299 lines
11 KiB
Python
299 lines
11 KiB
Python
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from unittest.mock import Mock, call, patch

import pytest

from common.constants import StatusEnum
from common.misc_utils import get_uuid


class TestDialogService:
    """Unit tests for DialogService.

    Tests that take the ``mock_dialog_service`` fixture talk to a
    ``unittest.mock`` stand-in patched over
    ``api.db.services.dialog_service.DialogService``. They therefore pin the
    call contract (method names, arguments, return shapes) rather than real
    persistence behaviour. The validation tests re-state the rule under test
    inline and check that it rejects bad input with the expected exception.
    """

    @staticmethod
    def _named_mock(name):
        """Return a ``Mock`` whose ``.name`` attribute equals ``name``.

        ``Mock(name=...)`` does NOT create a ``.name`` attribute: the ``name``
        constructor keyword only labels the mock's repr. The attribute must be
        assigned after construction.
        """
        mock = Mock()
        mock.name = name
        return mock

    @pytest.fixture
    def mock_dialog_service(self):
        """Yield a Mock patched over ``DialogService`` for the test's duration."""
        with patch('api.db.services.dialog_service.DialogService') as mock:
            yield mock

    @pytest.fixture
    def sample_dialog_data(self):
        """Sample dialog data for testing (a fresh dict per test)."""
        return {
            "id": get_uuid(),
            "tenant_id": "test_tenant_123",
            "name": "Test Dialog",
            "description": "A test dialog",
            "icon": "",
            "llm_id": "gpt-4",
            "llm_setting": {
                "temperature": 0.1,
                "top_p": 0.3,
                "frequency_penalty": 0.7,
                "presence_penalty": 0.4,
                "max_tokens": 512,
            },
            "prompt_config": {
                "system": "You are a helpful assistant",
                "prologue": "Hi! How can I help you?",
                "parameters": [],
                "empty_response": "Sorry! No relevant content found.",
            },
            "kb_ids": [],
            "top_n": 6,
            "top_k": 1024,
            "similarity_threshold": 0.2,
            "vector_similarity_weight": 0.3,
            "rerank_id": "",
            "status": StatusEnum.VALID.value,
        }

    def test_dialog_creation_success(self, mock_dialog_service, sample_dialog_data):
        """Test successful dialog creation"""
        mock_dialog_service.save.return_value = True

        result = mock_dialog_service.save(**sample_dialog_data)

        assert result is True
        mock_dialog_service.save.assert_called_once_with(**sample_dialog_data)

    def test_dialog_creation_with_invalid_name(self):
        """Test dialog creation with invalid name"""
        invalid_data = {
            "name": "",  # Empty name
            "tenant_id": "test_tenant",
        }

        # An empty or whitespace-only name must be rejected.
        with pytest.raises(ValueError, match="can't be empty"):
            if not invalid_data["name"].strip():
                raise ValueError("Dialog name can't be empty")

    def test_dialog_creation_with_long_name(self):
        """Test dialog creation with name exceeding 255 bytes"""
        long_name = "a" * 300
        # The limit is on UTF-8 byte length, so report bytes rather than
        # characters (they differ for non-ASCII names).
        name_bytes = len(long_name.encode("utf-8"))

        with pytest.raises(ValueError, match="larger than 255"):
            if name_bytes > 255:
                raise ValueError(f"Dialog name length is {name_bytes} which is larger than 255")

    def test_dialog_update_success(self, mock_dialog_service, sample_dialog_data):
        """Test successful dialog update"""
        dialog_id = sample_dialog_data["id"]
        update_data = {"name": "Updated Dialog Name"}

        mock_dialog_service.update_by_id.return_value = True
        result = mock_dialog_service.update_by_id(dialog_id, update_data)

        assert result is True
        mock_dialog_service.update_by_id.assert_called_once_with(dialog_id, update_data)

    def test_dialog_update_nonexistent(self, mock_dialog_service):
        """Test updating a non-existent dialog"""
        mock_dialog_service.update_by_id.return_value = False

        result = mock_dialog_service.update_by_id("nonexistent_id", {"name": "test"})

        assert result is False
        mock_dialog_service.update_by_id.assert_called_once_with("nonexistent_id", {"name": "test"})

    def test_dialog_get_by_id_success(self, mock_dialog_service, sample_dialog_data):
        """Test retrieving dialog by ID"""
        dialog_id = sample_dialog_data["id"]
        mock_dialog = Mock()
        mock_dialog.to_dict.return_value = sample_dialog_data

        # get_by_id returns an (exists, record) pair.
        mock_dialog_service.get_by_id.return_value = (True, mock_dialog)

        exists, dialog = mock_dialog_service.get_by_id(dialog_id)

        assert exists is True
        assert dialog.to_dict() == sample_dialog_data
        mock_dialog_service.get_by_id.assert_called_once_with(dialog_id)

    def test_dialog_get_by_id_not_found(self, mock_dialog_service):
        """Test retrieving non-existent dialog"""
        mock_dialog_service.get_by_id.return_value = (False, None)

        exists, dialog = mock_dialog_service.get_by_id("nonexistent_id")

        assert exists is False
        assert dialog is None
        mock_dialog_service.get_by_id.assert_called_once_with("nonexistent_id")

    def test_dialog_list_by_tenant(self, mock_dialog_service, sample_dialog_data):
        """Test listing dialogs by tenant ID"""
        tenant_id = "test_tenant_123"
        mock_dialogs = [Mock(to_dict=lambda: sample_dialog_data) for _ in range(3)]

        mock_dialog_service.query.return_value = mock_dialogs

        result = mock_dialog_service.query(
            tenant_id=tenant_id,
            status=StatusEnum.VALID.value,
        )

        assert len(result) == 3
        # Pin the filter arguments, not merely that query() was called.
        mock_dialog_service.query.assert_called_once_with(
            tenant_id=tenant_id,
            status=StatusEnum.VALID.value,
        )

    def test_dialog_delete_success(self, mock_dialog_service):
        """Test soft delete of dialog (status update)"""
        dialog_ids = ["id1", "id2", "id3"]
        # Soft delete flips each record's status to INVALID rather than removing it.
        dialog_list = [
            {"id": dialog_id, "status": StatusEnum.INVALID.value} for dialog_id in dialog_ids
        ]

        mock_dialog_service.update_many_by_id.return_value = True
        result = mock_dialog_service.update_many_by_id(dialog_list)

        assert result is True
        mock_dialog_service.update_many_by_id.assert_called_once_with(dialog_list)

    def test_dialog_with_knowledge_bases(self, mock_dialog_service, sample_dialog_data):
        """Test dialog creation with knowledge base IDs"""
        kb_ids = ["kb1", "kb2", "kb3"]
        sample_dialog_data["kb_ids"] = kb_ids

        mock_dialog_service.save.return_value = True
        result = mock_dialog_service.save(**sample_dialog_data)

        assert result is True
        # Verify the KB ids actually reached save(), not just the local dict.
        _, kwargs = mock_dialog_service.save.call_args
        assert kwargs["kb_ids"] == kb_ids

    def test_dialog_llm_settings_validation(self, sample_dialog_data):
        """Test LLM settings validation"""
        llm_setting = sample_dialog_data["llm_setting"]

        # Validate temperature range
        assert 0 <= llm_setting["temperature"] <= 2

        # Validate top_p range
        assert 0 <= llm_setting["top_p"] <= 1

        # Validate max_tokens is positive
        assert llm_setting["max_tokens"] > 0

    def test_dialog_prompt_config_validation(self, sample_dialog_data):
        """Test prompt configuration validation"""
        prompt_config = sample_dialog_data["prompt_config"]

        # Required fields should exist
        for key in ("system", "prologue", "parameters", "empty_response"):
            assert key in prompt_config

        # Parameters should be a list
        assert isinstance(prompt_config["parameters"], list)

    def test_dialog_duplicate_name_handling(self, mock_dialog_service):
        """Test handling of duplicate dialog names"""
        tenant_id = "test_tenant"
        name = "Duplicate Dialog"

        # First dialog with this name exists
        mock_dialog_service.query.return_value = [self._named_mock(name)]

        existing = mock_dialog_service.query(tenant_id=tenant_id, name=name)

        assert len(existing) > 0
        assert existing[0].name == name
        mock_dialog_service.query.assert_called_once_with(tenant_id=tenant_id, name=name)

    def test_dialog_similarity_threshold_validation(self, sample_dialog_data):
        """Test similarity threshold validation"""
        threshold = sample_dialog_data["similarity_threshold"]

        # Should be between 0 and 1
        assert 0 <= threshold <= 1

    def test_dialog_vector_similarity_weight_validation(self, sample_dialog_data):
        """Test vector similarity weight validation"""
        weight = sample_dialog_data["vector_similarity_weight"]

        # Should be between 0 and 1
        assert 0 <= weight <= 1

    def test_dialog_top_n_validation(self, sample_dialog_data):
        """Test top_n parameter validation"""
        top_n = sample_dialog_data["top_n"]

        # Should be positive integer
        assert isinstance(top_n, int)
        assert top_n > 0

    def test_dialog_top_k_validation(self, sample_dialog_data):
        """Test top_k parameter validation"""
        top_k = sample_dialog_data["top_k"]

        # Should be positive integer
        assert isinstance(top_k, int)
        assert top_k > 0

    def test_dialog_status_enum_validation(self, sample_dialog_data):
        """Test status field uses valid enum values"""
        status = sample_dialog_data["status"]

        # Should be valid status enum value
        assert status in [StatusEnum.VALID.value, StatusEnum.INVALID.value]

    @pytest.mark.parametrize("invalid_kb_ids", [
        None,  # None value
        "not_a_list",  # String instead of list
        123,  # Integer instead of list
    ])
    def test_dialog_invalid_kb_ids_type(self, invalid_kb_ids):
        """Test dialog creation with invalid kb_ids type"""
        # A wrong *type* is a TypeError, not a generic Exception.
        with pytest.raises(TypeError, match="kb_ids must be a list"):
            if not isinstance(invalid_kb_ids, list):
                raise TypeError("kb_ids must be a list")

    def test_dialog_empty_kb_ids_allowed(self, mock_dialog_service, sample_dialog_data):
        """Test dialog creation with empty kb_ids is allowed"""
        sample_dialog_data["kb_ids"] = []

        mock_dialog_service.save.return_value = True
        result = mock_dialog_service.save(**sample_dialog_data)

        assert result is True
        _, kwargs = mock_dialog_service.save.call_args
        assert kwargs["kb_ids"] == []

    def test_dialog_query_with_pagination(self, mock_dialog_service):
        """Test dialog listing with pagination"""
        page = 1
        page_size = 10
        total = 25

        # get_by_tenant_ids returns (page_of_records, total_count).
        mock_dialogs = [Mock() for _ in range(page_size)]
        mock_dialog_service.get_by_tenant_ids.return_value = (mock_dialogs, total)

        result, count = mock_dialog_service.get_by_tenant_ids(
            ["tenant1"], "user1", page, page_size, "create_time", True, "", None
        )

        assert len(result) == page_size
        assert count == total
        mock_dialog_service.get_by_tenant_ids.assert_called_once_with(
            ["tenant1"], "user1", page, page_size, "create_time", True, "", None
        )

    def test_dialog_search_by_keywords(self, mock_dialog_service):
        """Test dialog search with keywords"""
        keywords = "test"
        mock_dialogs = [
            self._named_mock("test dialog 1"),
            self._named_mock("test dialog 2"),
        ]

        mock_dialog_service.get_by_tenant_ids.return_value = (mock_dialogs, 2)

        result, count = mock_dialog_service.get_by_tenant_ids(
            ["tenant1"], "user1", 0, 0, "create_time", True, keywords, None
        )

        assert count == 2
        assert len(result) == 2
        assert all(keywords in dialog.name for dialog in result)

    def test_dialog_ordering(self, mock_dialog_service):
        """Test dialog ordering by different fields"""
        order_fields = ["create_time", "update_time", "name"]
        mock_dialog_service.get_by_tenant_ids.return_value = ([], 0)

        for field in order_fields:
            mock_dialog_service.get_by_tenant_ids(
                ["tenant1"], "user1", 0, 0, field, True, "", None
            )

        # Every ordering field must have been forwarded, in order, exactly once.
        mock_dialog_service.get_by_tenant_ids.assert_has_calls(
            [call(["tenant1"], "user1", 0, 0, field, True, "", None) for field in order_fields]
        )
        assert mock_dialog_service.get_by_tenant_ids.call_count == len(order_fields)