- Remove unused MagicMock import from test_dialog_service.py
- Remove unused MagicMock import from test_conversation_service.py
- Remove unused MagicMock import from test_canvas_service.py

Fixes CI linting errors (F401: imported but unused).
389 lines
14 KiB
Python
389 lines
14 KiB
Python
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from unittest.mock import Mock, patch, AsyncMock

import pytest

from common.misc_utils import get_uuid


class TestCanvasService:
    """Unit tests for the Canvas/Agent service layer.

    Service-level tests run against a patched ``UserCanvasService`` (see the
    ``mock_canvas_service`` fixture), so they check how the service API is
    called rather than real persistence. The remaining tests validate the
    shape of the canvas payload produced by the ``sample_canvas_data`` fixture.
    """

    @pytest.fixture
    def mock_canvas_service(self):
        """Yield a mock that replaces ``UserCanvasService`` while the test runs."""
        with patch('api.db.services.canvas_service.UserCanvasService') as mock:
            yield mock

    @pytest.fixture
    def sample_canvas_data(self):
        """Return a fresh canvas/agent payload with two components and one edge."""
        return {
            "id": get_uuid(),
            "user_id": "test_user_123",
            "title": "Test Agent",
            "description": "A test agent workflow",
            "avatar": "",
            "canvas_type": "agent",
            "canvas_category": "agent_canvas",
            "permission": "me",
            "dsl": {
                "components": [
                    {
                        "id": "comp1",
                        "type": "LLM",
                        "config": {
                            "model": "gpt-4",
                            "temperature": 0.7
                        }
                    },
                    {
                        "id": "comp2",
                        "type": "Retrieval",
                        "config": {
                            "kb_ids": ["kb1", "kb2"]
                        }
                    }
                ],
                "edges": [
                    {"source": "comp1", "target": "comp2"}
                ]
            }
        }

    def test_canvas_creation_success(self, mock_canvas_service, sample_canvas_data):
        """Test successful canvas creation"""
        mock_canvas_service.save.return_value = True

        result = mock_canvas_service.save(**sample_canvas_data)

        assert result is True
        mock_canvas_service.save.assert_called_once_with(**sample_canvas_data)

    def test_canvas_creation_with_duplicate_title(self, mock_canvas_service):
        """Test canvas creation with duplicate title"""
        user_id = "user123"
        title = "Duplicate Agent"

        # An existing canvas with the same title is what a duplicate check would find.
        mock_canvas_service.query.return_value = [Mock(title=title)]

        existing = mock_canvas_service.query(user_id=user_id, title=title)

        assert len(existing) > 0
        mock_canvas_service.query.assert_called_once_with(user_id=user_id, title=title)

    def test_canvas_get_by_id_success(self, mock_canvas_service, sample_canvas_data):
        """Test retrieving canvas by ID"""
        canvas_id = sample_canvas_data["id"]
        mock_canvas_service.get_by_canvas_id.return_value = (True, sample_canvas_data)

        exists, canvas = mock_canvas_service.get_by_canvas_id(canvas_id)

        assert exists is True
        assert canvas == sample_canvas_data
        mock_canvas_service.get_by_canvas_id.assert_called_once_with(canvas_id)

    def test_canvas_get_by_id_not_found(self, mock_canvas_service):
        """Test retrieving non-existent canvas"""
        mock_canvas_service.get_by_canvas_id.return_value = (False, None)

        exists, canvas = mock_canvas_service.get_by_canvas_id("nonexistent_id")

        assert exists is False
        assert canvas is None

    def test_canvas_update_success(self, mock_canvas_service, sample_canvas_data):
        """Test successful canvas update"""
        canvas_id = sample_canvas_data["id"]
        update_data = {"title": "Updated Agent Title"}
        mock_canvas_service.update_by_id.return_value = True

        result = mock_canvas_service.update_by_id(canvas_id, update_data)

        assert result is True
        # Verify the call arguments, mirroring the create/delete tests.
        mock_canvas_service.update_by_id.assert_called_once_with(canvas_id, update_data)

    def test_canvas_delete_success(self, mock_canvas_service):
        """Test canvas deletion"""
        canvas_id = get_uuid()
        mock_canvas_service.delete_by_id.return_value = True

        result = mock_canvas_service.delete_by_id(canvas_id)

        assert result is True
        mock_canvas_service.delete_by_id.assert_called_once_with(canvas_id)

    def test_canvas_dsl_structure_validation(self, sample_canvas_data):
        """Test DSL structure validation"""
        dsl = sample_canvas_data["dsl"]

        assert "components" in dsl
        assert "edges" in dsl
        assert isinstance(dsl["components"], list)
        assert isinstance(dsl["edges"], list)

    def test_canvas_component_validation(self, sample_canvas_data):
        """Test component structure validation"""
        components = sample_canvas_data["dsl"]["components"]

        for comp in components:
            assert "id" in comp
            assert "type" in comp
            assert "config" in comp

    def test_canvas_edge_validation(self, sample_canvas_data):
        """Test edge structure validation"""
        edges = sample_canvas_data["dsl"]["edges"]

        for edge in edges:
            assert "source" in edge
            assert "target" in edge

    def test_canvas_accessible_by_owner(self, mock_canvas_service):
        """Test canvas accessibility check for owner"""
        canvas_id = get_uuid()
        user_id = "user123"
        mock_canvas_service.accessible.return_value = True

        result = mock_canvas_service.accessible(canvas_id, user_id)

        assert result is True

    def test_canvas_not_accessible_by_non_owner(self, mock_canvas_service):
        """Test canvas accessibility check for non-owner"""
        canvas_id = get_uuid()
        user_id = "other_user"
        mock_canvas_service.accessible.return_value = False

        result = mock_canvas_service.accessible(canvas_id, user_id)

        assert result is False

    def test_canvas_permission_me(self, sample_canvas_data):
        """Test canvas with 'me' permission"""
        assert sample_canvas_data["permission"] == "me"

    def test_canvas_permission_team(self, sample_canvas_data):
        """Test canvas with 'team' permission"""
        sample_canvas_data["permission"] = "team"
        assert sample_canvas_data["permission"] == "team"

    def test_canvas_category_agent(self, sample_canvas_data):
        """Test canvas category as agent_canvas"""
        assert sample_canvas_data["canvas_category"] == "agent_canvas"

    def test_canvas_category_dataflow(self, sample_canvas_data):
        """Test canvas category as dataflow_canvas"""
        sample_canvas_data["canvas_category"] = "dataflow_canvas"
        assert sample_canvas_data["canvas_category"] == "dataflow_canvas"

    def test_canvas_dsl_serialization(self, sample_canvas_data):
        """Test DSL JSON serialization"""
        dsl = sample_canvas_data["dsl"]
        dsl_json = json.dumps(dsl)

        assert isinstance(dsl_json, str)

        # Round-trip: deserializing must reproduce the original structure.
        dsl_parsed = json.loads(dsl_json)
        assert dsl_parsed == dsl

    def test_canvas_with_llm_component(self, sample_canvas_data):
        """Test canvas with LLM component"""
        llm_comp = sample_canvas_data["dsl"]["components"][0]

        assert llm_comp["type"] == "LLM"
        assert "model" in llm_comp["config"]
        assert "temperature" in llm_comp["config"]

    def test_canvas_with_retrieval_component(self, sample_canvas_data):
        """Test canvas with Retrieval component"""
        retrieval_comp = sample_canvas_data["dsl"]["components"][1]

        assert retrieval_comp["type"] == "Retrieval"
        assert "kb_ids" in retrieval_comp["config"]

    def test_canvas_component_connection(self, sample_canvas_data):
        """Test component connections via edges"""
        edges = sample_canvas_data["dsl"]["edges"]
        components = sample_canvas_data["dsl"]["components"]

        comp_ids = {c["id"] for c in components}

        # Every edge endpoint must refer to a declared component.
        for edge in edges:
            assert edge["source"] in comp_ids
            assert edge["target"] in comp_ids

    def test_canvas_empty_dsl(self):
        """Test canvas with empty DSL"""
        empty_dsl = {
            "components": [],
            "edges": []
        }

        assert len(empty_dsl["components"]) == 0
        assert len(empty_dsl["edges"]) == 0

    def test_canvas_complex_workflow(self):
        """Test canvas with complex multi-component workflow"""
        complex_dsl = {
            "components": [
                {"id": "input", "type": "Input", "config": {}},
                {"id": "llm1", "type": "LLM", "config": {"model": "gpt-4"}},
                {"id": "retrieval", "type": "Retrieval", "config": {}},
                {"id": "llm2", "type": "LLM", "config": {"model": "gpt-3.5"}},
                {"id": "output", "type": "Output", "config": {}}
            ],
            "edges": [
                {"source": "input", "target": "llm1"},
                {"source": "llm1", "target": "retrieval"},
                {"source": "retrieval", "target": "llm2"},
                {"source": "llm2", "target": "output"}
            ]
        }

        assert len(complex_dsl["components"]) == 5
        assert len(complex_dsl["edges"]) == 4

    def test_canvas_version_creation(self, mock_canvas_service):
        """Test canvas version creation"""
        with patch('api.db.services.user_canvas_version.UserCanvasVersionService') as mock_version:
            canvas_id = get_uuid()
            dsl = {"components": [], "edges": []}
            title = "Version 1"
            mock_version.insert.return_value = True

            result = mock_version.insert(
                user_canvas_id=canvas_id,
                dsl=dsl,
                title=title
            )

            assert result is True
            mock_version.insert.assert_called_once_with(
                user_canvas_id=canvas_id,
                dsl=dsl,
                title=title
            )

    def test_canvas_list_by_user(self, mock_canvas_service):
        """Test listing canvases by user ID"""
        user_id = "user123"
        mock_canvases = [Mock() for _ in range(3)]
        mock_canvas_service.query.return_value = mock_canvases

        result = mock_canvas_service.query(user_id=user_id)

        assert len(result) == 3

    def test_canvas_list_by_category(self, mock_canvas_service):
        """Test listing canvases by category"""
        user_id = "user123"
        category = "agent_canvas"
        mock_canvases = [Mock() for _ in range(2)]
        mock_canvas_service.query.return_value = mock_canvases

        result = mock_canvas_service.query(
            user_id=user_id,
            canvas_category=category
        )

        assert len(result) == 2

    @pytest.mark.asyncio
    async def test_canvas_run_execution(self):
        """Test canvas run execution"""
        with patch('agent.canvas.Canvas') as MockCanvas:
            mock_canvas = MockCanvas.return_value
            expected = [{"content": "Response 1"}, {"content": "Response 2"}]

            # ``run`` is consumed with ``async for``, so calling it must return
            # an async iterable directly. Calling an AsyncMock returns a
            # coroutine, which ``async for`` rejects with TypeError, so ``run``
            # is a plain Mock that hands back an AsyncMock configured through
            # ``__aiter__`` to stream the expected responses.
            stream = AsyncMock()
            stream.__aiter__.return_value = expected
            mock_canvas.run = Mock(return_value=stream)

            results = []
            async for result in mock_canvas.run(query="test", files=[], user_id="user123"):
                results.append(result)

            assert len(results) == 2
            assert results == expected
            mock_canvas.run.assert_called_once_with(query="test", files=[], user_id="user123")

    def test_canvas_reset_functionality(self):
        """Test canvas reset functionality"""
        with patch('agent.canvas.Canvas') as MockCanvas:
            mock_canvas = MockCanvas.return_value
            mock_canvas.reset = Mock()

            mock_canvas.reset()

            mock_canvas.reset.assert_called_once()

    def test_canvas_component_input_form(self):
        """Test getting component input form"""
        with patch('agent.canvas.Canvas') as MockCanvas:
            mock_canvas = MockCanvas.return_value
            mock_canvas.get_component_input_form = Mock(return_value={
                "fields": [
                    {"name": "query", "type": "text", "required": True}
                ]
            })

            form = mock_canvas.get_component_input_form("comp1")

            assert "fields" in form
            assert len(form["fields"]) > 0

    def test_canvas_debug_mode(self):
        """Test canvas debug mode execution"""
        with patch('agent.canvas.Canvas') as MockCanvas:
            mock_canvas = MockCanvas.return_value
            component = Mock()
            component.invoke = Mock()
            component.output = Mock(return_value={"result": "debug output"})

            # The component lookup returns a dict wrapping the component under "obj".
            mock_canvas.get_component = Mock(return_value={"obj": component})

            comp_data = mock_canvas.get_component("comp1")
            comp_data["obj"].invoke(param1="value1")
            output = comp_data["obj"].output()

            assert "result" in output
            component.invoke.assert_called_once_with(param1="value1")

    def test_canvas_title_length_validation(self):
        """Test canvas title length validation"""
        long_title = "a" * 300

        # Assert the precondition instead of guarding with ``if`` so the test
        # can never pass without reaching the ``pytest.raises`` check.
        assert len(long_title) > 255

        # NOTE(review): the exception is raised by the test itself; no
        # title-validation helper is called here.
        with pytest.raises(Exception, match="exceeds 255"):
            raise Exception(f"Canvas title length {len(long_title)} exceeds 255")

    @pytest.mark.parametrize("canvas_type", [
        "agent",
        "workflow",
        "pipeline",
        "custom"
    ])
    def test_canvas_different_types(self, canvas_type, sample_canvas_data):
        """Test canvas with different types"""
        sample_canvas_data["canvas_type"] = canvas_type
        assert sample_canvas_data["canvas_type"] == canvas_type

    def test_canvas_avatar_base64(self, sample_canvas_data):
        """Test canvas with base64 avatar"""
        sample_canvas_data["avatar"] = "data:image/png;base64,iVBORw0KGgoAAAANS..."
        assert sample_canvas_data["avatar"].startswith("data:image/")

    def test_canvas_batch_delete(self, mock_canvas_service):
        """Test batch deletion of canvases"""
        canvas_ids = [get_uuid() for _ in range(3)]
        # The stubbed result is loop-invariant, so configure it once up front.
        mock_canvas_service.delete_by_id.return_value = True

        for canvas_id in canvas_ids:
            result = mock_canvas_service.delete_by_id(canvas_id)
            assert result is True

        # Exactly one delete call per canvas.
        assert mock_canvas_service.delete_by_id.call_count == len(canvas_ids)