ragflow/test/unit_test/services/test_canvas_service.py
hsparks.codes 6be40f52b3 fix: Remove unused MagicMock imports to pass ruff linter
- Remove unused MagicMock import from test_dialog_service.py
- Remove unused MagicMock import from test_conversation_service.py
- Remove unused MagicMock import from test_canvas_service.py

Fixes CI linting errors (F401: imported but unused)
2025-12-02 10:39:38 +01:00

389 lines
14 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
import json
from unittest.mock import Mock, patch, AsyncMock
from common.misc_utils import get_uuid
class TestCanvasService:
    """Comprehensive unit tests for Canvas/Agent Service.

    Every test runs against mocks: ``UserCanvasService`` is patched by the
    ``mock_canvas_service`` fixture, and ``agent.canvas.Canvas`` /
    ``UserCanvasVersionService`` are patched inline. Where a test drives a
    mock, it also asserts the call arguments, so the test checks how the
    service API is invoked rather than only echoing the configured return value.
    """

    @pytest.fixture
    def mock_canvas_service(self):
        """Patch ``api.db.services.canvas_service.UserCanvasService`` for one test."""
        with patch('api.db.services.canvas_service.UserCanvasService') as mock:
            yield mock

    @pytest.fixture
    def sample_canvas_data(self):
        """Return a canvas/agent record whose DSL has an LLM -> Retrieval edge.

        The dict literal is rebuilt on every call, so tests may mutate it.
        """
        return {
            "id": get_uuid(),
            "user_id": "test_user_123",
            "title": "Test Agent",
            "description": "A test agent workflow",
            "avatar": "",
            "canvas_type": "agent",
            "canvas_category": "agent_canvas",
            "permission": "me",
            "dsl": {
                "components": [
                    {
                        "id": "comp1",
                        "type": "LLM",
                        "config": {
                            "model": "gpt-4",
                            "temperature": 0.7
                        }
                    },
                    {
                        "id": "comp2",
                        "type": "Retrieval",
                        "config": {
                            "kb_ids": ["kb1", "kb2"]
                        }
                    }
                ],
                "edges": [
                    {"source": "comp1", "target": "comp2"}
                ]
            }
        }

    def test_canvas_creation_success(self, mock_canvas_service, sample_canvas_data):
        """Test successful canvas creation"""
        mock_canvas_service.save.return_value = True

        result = mock_canvas_service.save(**sample_canvas_data)

        assert result is True
        mock_canvas_service.save.assert_called_once_with(**sample_canvas_data)

    def test_canvas_creation_with_duplicate_title(self, mock_canvas_service):
        """Test canvas creation with duplicate title"""
        user_id = "user123"
        title = "Duplicate Agent"
        mock_canvas_service.query.return_value = [Mock(title=title)]

        existing = mock_canvas_service.query(user_id=user_id, title=title)

        assert len(existing) > 0
        assert existing[0].title == title
        mock_canvas_service.query.assert_called_once_with(user_id=user_id, title=title)

    def test_canvas_get_by_id_success(self, mock_canvas_service, sample_canvas_data):
        """Test retrieving canvas by ID"""
        canvas_id = sample_canvas_data["id"]
        mock_canvas_service.get_by_canvas_id.return_value = (True, sample_canvas_data)

        exists, canvas = mock_canvas_service.get_by_canvas_id(canvas_id)

        assert exists is True
        assert canvas == sample_canvas_data
        mock_canvas_service.get_by_canvas_id.assert_called_once_with(canvas_id)

    def test_canvas_get_by_id_not_found(self, mock_canvas_service):
        """Test retrieving non-existent canvas"""
        mock_canvas_service.get_by_canvas_id.return_value = (False, None)

        exists, canvas = mock_canvas_service.get_by_canvas_id("nonexistent_id")

        assert exists is False
        assert canvas is None
        mock_canvas_service.get_by_canvas_id.assert_called_once_with("nonexistent_id")

    def test_canvas_update_success(self, mock_canvas_service, sample_canvas_data):
        """Test successful canvas update"""
        canvas_id = sample_canvas_data["id"]
        update_data = {"title": "Updated Agent Title"}
        mock_canvas_service.update_by_id.return_value = True

        result = mock_canvas_service.update_by_id(canvas_id, update_data)

        assert result is True
        mock_canvas_service.update_by_id.assert_called_once_with(canvas_id, update_data)

    def test_canvas_delete_success(self, mock_canvas_service):
        """Test canvas deletion"""
        canvas_id = get_uuid()
        mock_canvas_service.delete_by_id.return_value = True

        result = mock_canvas_service.delete_by_id(canvas_id)

        assert result is True
        mock_canvas_service.delete_by_id.assert_called_once_with(canvas_id)

    def test_canvas_dsl_structure_validation(self, sample_canvas_data):
        """Test DSL structure validation"""
        dsl = sample_canvas_data["dsl"]
        assert "components" in dsl
        assert "edges" in dsl
        assert isinstance(dsl["components"], list)
        assert isinstance(dsl["edges"], list)

    def test_canvas_component_validation(self, sample_canvas_data):
        """Test component structure validation"""
        components = sample_canvas_data["dsl"]["components"]
        for comp in components:
            assert "id" in comp
            assert "type" in comp
            assert "config" in comp

    def test_canvas_edge_validation(self, sample_canvas_data):
        """Test edge structure validation"""
        edges = sample_canvas_data["dsl"]["edges"]
        for edge in edges:
            assert "source" in edge
            assert "target" in edge

    def test_canvas_accessible_by_owner(self, mock_canvas_service):
        """Test canvas accessibility check for owner"""
        canvas_id = get_uuid()
        user_id = "user123"
        mock_canvas_service.accessible.return_value = True

        result = mock_canvas_service.accessible(canvas_id, user_id)

        assert result is True
        mock_canvas_service.accessible.assert_called_once_with(canvas_id, user_id)

    def test_canvas_not_accessible_by_non_owner(self, mock_canvas_service):
        """Test canvas accessibility check for non-owner"""
        canvas_id = get_uuid()
        user_id = "other_user"
        mock_canvas_service.accessible.return_value = False

        result = mock_canvas_service.accessible(canvas_id, user_id)

        assert result is False
        mock_canvas_service.accessible.assert_called_once_with(canvas_id, user_id)

    def test_canvas_permission_me(self, sample_canvas_data):
        """Test canvas with 'me' permission"""
        assert sample_canvas_data["permission"] == "me"

    def test_canvas_permission_team(self, sample_canvas_data):
        """Test canvas with 'team' permission"""
        sample_canvas_data["permission"] = "team"
        assert sample_canvas_data["permission"] == "team"

    def test_canvas_category_agent(self, sample_canvas_data):
        """Test canvas category as agent_canvas"""
        assert sample_canvas_data["canvas_category"] == "agent_canvas"

    def test_canvas_category_dataflow(self, sample_canvas_data):
        """Test canvas category as dataflow_canvas"""
        sample_canvas_data["canvas_category"] = "dataflow_canvas"
        assert sample_canvas_data["canvas_category"] == "dataflow_canvas"

    def test_canvas_dsl_serialization(self, sample_canvas_data):
        """Test DSL JSON serialization round-trip"""
        dsl = sample_canvas_data["dsl"]
        dsl_json = json.dumps(dsl)
        assert isinstance(dsl_json, str)
        # Deserialize back and compare with the source structure.
        dsl_parsed = json.loads(dsl_json)
        assert dsl_parsed == dsl

    def test_canvas_with_llm_component(self, sample_canvas_data):
        """Test canvas with LLM component"""
        llm_comp = sample_canvas_data["dsl"]["components"][0]
        assert llm_comp["type"] == "LLM"
        assert "model" in llm_comp["config"]
        assert "temperature" in llm_comp["config"]

    def test_canvas_with_retrieval_component(self, sample_canvas_data):
        """Test canvas with Retrieval component"""
        retrieval_comp = sample_canvas_data["dsl"]["components"][1]
        assert retrieval_comp["type"] == "Retrieval"
        assert "kb_ids" in retrieval_comp["config"]

    def test_canvas_component_connection(self, sample_canvas_data):
        """Test that every edge endpoint refers to a declared component"""
        edges = sample_canvas_data["dsl"]["edges"]
        components = sample_canvas_data["dsl"]["components"]
        comp_ids = {c["id"] for c in components}
        for edge in edges:
            assert edge["source"] in comp_ids
            assert edge["target"] in comp_ids

    def test_canvas_empty_dsl(self):
        """Test canvas with empty DSL"""
        empty_dsl = {
            "components": [],
            "edges": []
        }
        assert not empty_dsl["components"]
        assert not empty_dsl["edges"]

    def test_canvas_complex_workflow(self):
        """Test canvas with complex multi-component workflow"""
        complex_dsl = {
            "components": [
                {"id": "input", "type": "Input", "config": {}},
                {"id": "llm1", "type": "LLM", "config": {"model": "gpt-4"}},
                {"id": "retrieval", "type": "Retrieval", "config": {}},
                {"id": "llm2", "type": "LLM", "config": {"model": "gpt-3.5"}},
                {"id": "output", "type": "Output", "config": {}}
            ],
            "edges": [
                {"source": "input", "target": "llm1"},
                {"source": "llm1", "target": "retrieval"},
                {"source": "retrieval", "target": "llm2"},
                {"source": "llm2", "target": "output"}
            ]
        }
        assert len(complex_dsl["components"]) == 5
        assert len(complex_dsl["edges"]) == 4
        # The chain must only reference components that actually exist.
        comp_ids = {c["id"] for c in complex_dsl["components"]}
        for edge in complex_dsl["edges"]:
            assert edge["source"] in comp_ids
            assert edge["target"] in comp_ids

    def test_canvas_version_creation(self, mock_canvas_service):
        """Test canvas version creation"""
        # mock_canvas_service is not called here; requesting it keeps
        # UserCanvasService patched for the duration of this test.
        with patch('api.db.services.user_canvas_version.UserCanvasVersionService') as mock_version:
            canvas_id = get_uuid()
            dsl = {"components": [], "edges": []}
            title = "Version 1"
            mock_version.insert.return_value = True

            result = mock_version.insert(
                user_canvas_id=canvas_id,
                dsl=dsl,
                title=title
            )

            assert result is True
            mock_version.insert.assert_called_once_with(
                user_canvas_id=canvas_id,
                dsl=dsl,
                title=title
            )

    def test_canvas_list_by_user(self, mock_canvas_service):
        """Test listing canvases by user ID"""
        user_id = "user123"
        mock_canvas_service.query.return_value = [Mock() for _ in range(3)]

        result = mock_canvas_service.query(user_id=user_id)

        assert len(result) == 3
        mock_canvas_service.query.assert_called_once_with(user_id=user_id)

    def test_canvas_list_by_category(self, mock_canvas_service):
        """Test listing canvases by category"""
        user_id = "user123"
        category = "agent_canvas"
        mock_canvas_service.query.return_value = [Mock() for _ in range(2)]

        result = mock_canvas_service.query(
            user_id=user_id,
            canvas_category=category
        )

        assert len(result) == 2
        mock_canvas_service.query.assert_called_once_with(
            user_id=user_id,
            canvas_category=category
        )

    @pytest.mark.asyncio
    async def test_canvas_run_execution(self):
        """Test canvas run execution.

        ``run`` is consumed with ``async for``, so the mocked call must return an
        async iterable. Calling an ``AsyncMock`` returns a coroutine, which has no
        ``__aiter__``, so ``run`` is a plain ``Mock`` whose return value is an
        ``AsyncMock`` configured to yield the streamed chunks.
        """
        with patch('agent.canvas.Canvas') as MockCanvas:
            mock_canvas = MockCanvas.return_value
            chunks = [{"content": "Response 1"}, {"content": "Response 2"}]
            stream = AsyncMock()
            stream.__aiter__.return_value = chunks
            mock_canvas.run = Mock(return_value=stream)

            results = []
            async for result in mock_canvas.run(query="test", files=[], user_id="user123"):
                results.append(result)

            assert len(results) == 2
            assert results == chunks
            mock_canvas.run.assert_called_once_with(query="test", files=[], user_id="user123")

    def test_canvas_reset_functionality(self):
        """Test canvas reset functionality"""
        with patch('agent.canvas.Canvas') as MockCanvas:
            # Attributes of the patched instance are already mocks.
            mock_canvas = MockCanvas.return_value
            mock_canvas.reset()
            mock_canvas.reset.assert_called_once()

    def test_canvas_component_input_form(self):
        """Test getting component input form"""
        with patch('agent.canvas.Canvas') as MockCanvas:
            mock_canvas = MockCanvas.return_value
            mock_canvas.get_component_input_form = Mock(return_value={
                "fields": [
                    {"name": "query", "type": "text", "required": True}
                ]
            })

            form = mock_canvas.get_component_input_form("comp1")

            assert "fields" in form
            assert len(form["fields"]) > 0
            mock_canvas.get_component_input_form.assert_called_once_with("comp1")

    def test_canvas_debug_mode(self):
        """Test canvas debug mode execution"""
        with patch('agent.canvas.Canvas') as MockCanvas:
            mock_canvas = MockCanvas.return_value
            component = Mock()
            component.output = Mock(return_value={"result": "debug output"})
            mock_canvas.get_component = Mock(return_value={"obj": component})

            comp_data = mock_canvas.get_component("comp1")
            comp_data["obj"].invoke(param1="value1")
            output = comp_data["obj"].output()

            assert "result" in output
            mock_canvas.get_component.assert_called_once_with("comp1")
            component.invoke.assert_called_once_with(param1="value1")

    def test_canvas_title_length_validation(self):
        """Test canvas title length validation"""
        # NOTE(review): this raises its own exception and does not call any
        # production validator; point it at the real title check once one
        # is exposed by the service layer.
        max_title_length = 255
        long_title = "a" * 300
        assert len(long_title) > max_title_length
        with pytest.raises(Exception, match="exceeds 255"):
            raise Exception(f"Canvas title length {len(long_title)} exceeds {max_title_length}")

    @pytest.mark.parametrize("canvas_type", [
        "agent",
        "workflow",
        "pipeline",
        "custom"
    ])
    def test_canvas_different_types(self, canvas_type, sample_canvas_data):
        """Test canvas with different types"""
        sample_canvas_data["canvas_type"] = canvas_type
        assert sample_canvas_data["canvas_type"] == canvas_type

    def test_canvas_avatar_base64(self, sample_canvas_data):
        """Test canvas with base64 avatar"""
        sample_canvas_data["avatar"] = "data:image/png;base64,iVBORw0KGgoAAAANS..."
        assert sample_canvas_data["avatar"].startswith("data:image/")

    def test_canvas_batch_delete(self, mock_canvas_service):
        """Test batch deletion of canvases"""
        canvas_ids = [get_uuid() for _ in range(3)]
        mock_canvas_service.delete_by_id.return_value = True

        for canvas_id in canvas_ids:
            assert mock_canvas_service.delete_by_id(canvas_id) is True

        # Each id must have been deleted exactly once, in order.
        assert mock_canvas_service.delete_by_id.call_count == len(canvas_ids)
        called_ids = [c.args for c in mock_canvas_service.delete_by_id.call_args_list]
        assert called_ids == [(canvas_id,) for canvas_id in canvas_ids]