"""
|
|
LightRAG Ollama Compatibility Interface Test Script
|
|
|
|
This script tests the LightRAG's Ollama compatibility interface, including:
|
|
1. Basic functionality tests (streaming and non-streaming responses)
|
|
2. Query mode tests (local, global, naive, hybrid)
|
|
3. Error handling tests (including streaming and non-streaming scenarios)
|
|
|
|
All responses use the JSON Lines format, complying with the Ollama API specification.
|
|
"""
|
|
|
|
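# This module can be run under pytest (the tests are marked with
# @pytest.mark.integration and @pytest.mark.requires_api) or standalone via
# `python <this file>`; see parse_args() below for the command-line options.
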
import argparse
import json
import time
from collections.abc import Callable
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import Any

import pytest
import requests


class ErrorCode(Enum):
    """Error codes for MCP errors"""

    InvalidRequest = auto()
    InternalError = auto()


class McpError(Exception):
    """Base exception class for MCP errors"""

    def __init__(self, code: ErrorCode, message: str):
        self.code = code
        self.message = message
        super().__init__(message)


DEFAULT_CONFIG = {
    'server': {
        'host': 'localhost',
        'port': 9621,
        'model': 'lightrag:latest',
        'timeout': 300,
        'max_retries': 1,
        'retry_delay': 1,
    },
    'test_cases': {
        'basic': {'query': '唐僧有几个徒弟'},  # "How many disciples does Tang Seng have?"
        'generate': {'query': '电视剧西游记导演是谁'},  # "Who directed the TV series Journey to the West?"
    },
}

# Example conversation history for testing
EXAMPLE_CONVERSATION = [
    {'role': 'user', 'content': '你好'},  # "Hello"
    {'role': 'assistant', 'content': '你好!我是一个AI助手,很高兴为你服务。'},  # "Hello! I'm an AI assistant, happy to serve you."
    {'role': 'user', 'content': 'Who are you?'},
    {'role': 'assistant', 'content': "I'm a Knowledge base query assistant."},
]


class OutputControl:
    """Output control class, manages the verbosity of test output"""

    _verbose: bool = False

    @classmethod
    def set_verbose(cls, verbose: bool) -> None:
        cls._verbose = verbose

    @classmethod
    def is_verbose(cls) -> bool:
        return cls._verbose


@dataclass
class ExecutionResult:
    """Test execution result data class"""

    name: str
    success: bool
    duration: float
    error: str | None = None
    timestamp: str = ''

    def __post_init__(self) -> None:
        if not self.timestamp:
            self.timestamp = datetime.now().isoformat()


class ExecutionStats:
    """Test execution statistics"""

    def __init__(self) -> None:
        self.results: list[ExecutionResult] = []
        self.start_time = datetime.now()

    def add_result(self, result: ExecutionResult) -> None:
        self.results.append(result)

    def export_results(self, path: str = 'test_results.json') -> None:
        """Export test results to a JSON file

        Args:
            path: Output file path
        """
        results_data = {
            'start_time': self.start_time.isoformat(),
            'end_time': datetime.now().isoformat(),
            'results': [asdict(r) for r in self.results],
            'summary': {
                'total': len(self.results),
                'passed': sum(1 for r in self.results if r.success),
                'failed': sum(1 for r in self.results if not r.success),
                'total_duration': sum(r.duration for r in self.results),
            },
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, ensure_ascii=False, indent=2)
        print(f'\nTest results saved to: {path}')

    def print_summary(self) -> None:
        total = len(self.results)
        passed = sum(1 for r in self.results if r.success)
        failed = total - passed
        duration = sum(r.duration for r in self.results)

        print('\n=== Test Summary ===')
        print(f'Start time: {self.start_time.strftime("%Y-%m-%d %H:%M:%S")}')
        print(f'Total duration: {duration:.2f} seconds')
        print(f'Total tests: {total}')
        print(f'Passed: {passed}')
        print(f'Failed: {failed}')

        if failed > 0:
            print('\nFailed tests:')
            for result in self.results:
                if not result.success:
                    print(f'- {result.name}: {result.error}')


def make_request(url: str, data: dict[str, Any], stream: bool = False, check_status: bool = True) -> requests.Response:
    """Send an HTTP request with a retry mechanism

    Args:
        url: Request URL
        data: Request data
        stream: Whether to use a streaming response
        check_status: Whether to check the HTTP status code (default: True)

    Returns:
        requests.Response: Response object

    Raises:
        requests.exceptions.RequestException: Request failed after all retries
        requests.exceptions.HTTPError: HTTP status code indicates an error (when check_status is True)
    """
    server_config = CONFIG['server']
    max_retries = server_config['max_retries']
    retry_delay = server_config['retry_delay']
    timeout = server_config['timeout']

    for attempt in range(max_retries):
        try:
            response = requests.post(url, json=data, stream=stream, timeout=timeout)
            if check_status:
                response.raise_for_status()  # Raises HTTPError for 4xx/5xx responses
            return response
        except requests.exceptions.RequestException as e:
            if attempt == max_retries - 1:  # Last retry
                raise
            print(f'\nRequest failed, retrying in {retry_delay} seconds: {e!s}')
            time.sleep(retry_delay)
    raise RuntimeError('Max retries exceeded')


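# Illustrative use of make_request (a sketch; the real call sites are in the
# test functions below):
#
#     response = make_request(get_base_url('chat'), data, stream=False, check_status=False)
#     print(response.status_code)
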
def load_config() -> dict[str, Any]:
    """Load the configuration file

    First try to load config.json from the current directory;
    if it doesn't exist, fall back to the default configuration.

    Returns:
        Configuration dictionary
    """
    config_path = Path('config.json')
    if config_path.exists():
        with open(config_path, encoding='utf-8') as f:
            return json.load(f)
    return DEFAULT_CONFIG


def print_json_response(data: dict[str, Any], title: str = '', indent: int = 2) -> None:
    """Format and print JSON response data

    Args:
        data: Data dictionary to print
        title: Title to print
        indent: Number of spaces for JSON indentation
    """
    if OutputControl.is_verbose():
        if title:
            print(f'\n=== {title} ===')
        print(json.dumps(data, ensure_ascii=False, indent=indent))


# Global configuration
CONFIG = load_config()


def get_base_url(endpoint: str = 'chat') -> str:
    """Return the base URL for the specified endpoint

    Args:
        endpoint: API endpoint name (chat or generate)

    Returns:
        Complete URL for the endpoint, e.g. http://localhost:9621/api/chat
    """
    server = CONFIG['server']
    return f'http://{server["host"]}:{server["port"]}/api/{endpoint}'


def create_chat_request_data(
    content: str,
    stream: bool = False,
    model: str | None = None,
    conversation_history: list[dict[str, str]] | None = None,
) -> dict[str, Any]:
    """Create chat request data

    Args:
        content: User message content
        stream: Whether to use a streaming response
        model: Model name
        conversation_history: List of previous conversation messages

    Returns:
        Dictionary containing complete chat request data
    """
    # Copy the history so the caller's list (e.g. EXAMPLE_CONVERSATION) is not
    # mutated across calls
    messages = list(conversation_history) if conversation_history else []
    messages.append({'role': 'user', 'content': content})

    return {
        'model': model or CONFIG['server']['model'],
        'messages': messages,
        'stream': stream,
    }


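# For example (illustrative, assuming the default model from DEFAULT_CONFIG):
#
#     create_chat_request_data('hi', conversation_history=[{'role': 'user', 'content': 'hello'}])
#
# returns:
#
#     {'model': 'lightrag:latest',
#      'messages': [{'role': 'user', 'content': 'hello'},
#                   {'role': 'user', 'content': 'hi'}],
#      'stream': False}
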
def create_generate_request_data(
    prompt: str,
    system: str | None = None,
    stream: bool = False,
    model: str | None = None,
    options: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Create generate request data

    Args:
        prompt: Generation prompt
        system: System prompt
        stream: Whether to use a streaming response
        model: Model name
        options: Additional options

    Returns:
        Dictionary containing complete generate request data
    """
    data: dict[str, Any] = {
        'model': model or CONFIG['server']['model'],
        'prompt': prompt,
        'stream': stream,
    }
    if system:
        data['system'] = system
    if options:
        data['options'] = options
    return data


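# For example (illustrative, assuming the default model from DEFAULT_CONFIG):
#
#     create_generate_request_data('hello', system='Answer briefly')
#
# returns:
#
#     {'model': 'lightrag:latest', 'prompt': 'hello', 'stream': False, 'system': 'Answer briefly'}
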
# Global test statistics
STATS = ExecutionStats()


def run_test(func: Callable[[], None], name: str) -> None:
    """Run a test and record the result

    Args:
        func: Test function
        name: Test name
    """
    start_time = time.time()
    try:
        func()
        duration = time.time() - start_time
        STATS.add_result(ExecutionResult(name, True, duration))
    except Exception as e:
        duration = time.time() - start_time
        STATS.add_result(ExecutionResult(name, False, duration, str(e)))
        raise


@pytest.mark.integration
@pytest.mark.requires_api
def test_non_stream_chat() -> None:
    """Test non-streaming call to /api/chat endpoint"""
    url = get_base_url()

    # Send request with conversation history
    data = create_chat_request_data(
        CONFIG['test_cases']['basic']['query'],
        stream=False,
        conversation_history=EXAMPLE_CONVERSATION,
    )
    response = make_request(url, data)

    # Print response
    if OutputControl.is_verbose():
        print('\n=== Non-streaming call response ===')
    response_json = response.json()

    # Print response content
    print_json_response(
        {'model': response_json['model'], 'message': response_json['message']},
        'Response content',
    )


@pytest.mark.integration
@pytest.mark.requires_api
def test_stream_chat() -> None:
    """Test streaming call to /api/chat endpoint

    Uses the JSON Lines format to process streaming responses; each line is a complete JSON object.
    Response format:
    {
        "model": "lightrag:latest",
        "created_at": "2024-01-15T00:00:00Z",
        "message": {
            "role": "assistant",
            "content": "Partial response content",
            "images": null
        },
        "done": false
    }

    The last message contains performance statistics, with done set to true.
    """
    url = get_base_url()

    # Send request with conversation history
    data = create_chat_request_data(
        CONFIG['test_cases']['basic']['query'],
        stream=True,
        conversation_history=EXAMPLE_CONVERSATION,
    )
    response = make_request(url, data, stream=True)

    if OutputControl.is_verbose():
        print('\n=== Streaming call response ===')
    output_buffer = []
    try:
        for line in response.iter_lines():
            if line:  # Skip empty lines
                try:
                    # Decode and parse JSON
                    data = json.loads(line.decode('utf-8'))
                    if data.get('done', True):  # If it's the completion marker
                        if 'total_duration' in data:  # Final performance statistics message
                            # print_json_response(data, "Performance statistics")
                            break
                    else:  # Normal content message
                        message = data.get('message', {})
                        content = message.get('content', '')
                        if content:  # Only collect non-empty content
                            output_buffer.append(content)
                            print(content, end='', flush=True)  # Print content in real time
                except json.JSONDecodeError:
                    print('Error decoding JSON from response line')
    finally:
        response.close()  # Ensure the response connection is closed

    # Print a trailing newline
    print()


@pytest.mark.integration
@pytest.mark.requires_api
def test_query_modes() -> None:
    """Test different query mode prefixes

    Supported query modes:
    - /local: Local retrieval mode, searches only in highly relevant documents
    - /global: Global retrieval mode, searches across all documents
    - /naive: Naive mode, does not use any optimization strategies
    - /hybrid: Hybrid mode (default), combines multiple strategies
    - /mix: Mix mode, combines knowledge graph and vector retrieval

    Each mode returns responses in the same format, but with different retrieval strategies.
    """
    url = get_base_url()
    modes = ['local', 'global', 'naive', 'hybrid', 'mix']

    for mode in modes:
        if OutputControl.is_verbose():
            print(f'\n=== Testing /{mode} mode ===')
        data = create_chat_request_data(f'/{mode} {CONFIG["test_cases"]["basic"]["query"]}', stream=False)

        # Send request
        response = make_request(url, data)
        response_json = response.json()

        # Print response content
        print_json_response({'model': response_json['model'], 'message': response_json['message']})


def create_error_test_data(error_type: str) -> dict[str, Any]:
    """Create request data for error testing

    Args:
        error_type: Error type, supported values:
            - empty_messages: Empty message list
            - invalid_role: Invalid role field
            - missing_content: Missing content field

    Returns:
        Request dictionary containing the error data
    """
    error_data = {
        'empty_messages': {'model': 'lightrag:latest', 'messages': [], 'stream': True},
        'invalid_role': {
            'model': 'lightrag:latest',
            'messages': [{'invalid_role': 'user', 'content': 'Test message'}],
            'stream': True,
        },
        'missing_content': {
            'model': 'lightrag:latest',
            'messages': [{'role': 'user'}],
            'stream': True,
        },
    }
    return error_data.get(error_type, error_data['empty_messages'])


@pytest.mark.integration
@pytest.mark.requires_api
def test_stream_error_handling() -> None:
    """Test error handling for streaming responses

    Test scenarios:
    1. Empty message list
    2. Message format error (missing required fields)

    Error responses should be returned immediately without establishing a streaming connection.
    The status code should be 4xx, and detailed error information should be returned.
    """
    url = get_base_url()

    if OutputControl.is_verbose():
        print('\n=== Testing streaming response error handling ===')

    # Test empty message list
    if OutputControl.is_verbose():
        print('\n--- Testing empty message list (streaming) ---')
    data = create_error_test_data('empty_messages')
    response = make_request(url, data, stream=True, check_status=False)
    print(f'Status code: {response.status_code}')
    if response.status_code != 200:
        print_json_response(response.json(), 'Error message')
    response.close()

    # Test invalid role field
    if OutputControl.is_verbose():
        print('\n--- Testing invalid role field (streaming) ---')
    data = create_error_test_data('invalid_role')
    response = make_request(url, data, stream=True, check_status=False)
    print(f'Status code: {response.status_code}')
    if response.status_code != 200:
        print_json_response(response.json(), 'Error message')
    response.close()

    # Test missing content field
    if OutputControl.is_verbose():
        print('\n--- Testing missing content field (streaming) ---')
    data = create_error_test_data('missing_content')
    response = make_request(url, data, stream=True, check_status=False)
    print(f'Status code: {response.status_code}')
    if response.status_code != 200:
        print_json_response(response.json(), 'Error message')
    response.close()


@pytest.mark.integration
@pytest.mark.requires_api
def test_error_handling() -> None:
    """Test error handling for non-streaming responses

    Test scenarios:
    1. Empty message list
    2. Message format error (missing required fields)

    Error response format:
    {
        "detail": "Error description"
    }

    All errors should return appropriate HTTP status codes and clear error messages.
    """
    url = get_base_url()

    if OutputControl.is_verbose():
        print('\n=== Testing error handling ===')

    # Test empty message list
    if OutputControl.is_verbose():
        print('\n--- Testing empty message list ---')
    data = create_error_test_data('empty_messages')
    data['stream'] = False  # Change to non-streaming mode
    response = make_request(url, data, check_status=False)
    print(f'Status code: {response.status_code}')
    print_json_response(response.json(), 'Error message')

    # Test invalid role field
    if OutputControl.is_verbose():
        print('\n--- Testing invalid role field ---')
    data = create_error_test_data('invalid_role')
    data['stream'] = False  # Change to non-streaming mode
    response = make_request(url, data, check_status=False)
    print(f'Status code: {response.status_code}')
    print_json_response(response.json(), 'Error message')

    # Test missing content field
    if OutputControl.is_verbose():
        print('\n--- Testing missing content field ---')
    data = create_error_test_data('missing_content')
    data['stream'] = False  # Change to non-streaming mode
    response = make_request(url, data, check_status=False)
    print(f'Status code: {response.status_code}')
    print_json_response(response.json(), 'Error message')


@pytest.mark.integration
@pytest.mark.requires_api
def test_non_stream_generate() -> None:
    """Test non-streaming call to /api/generate endpoint"""
    url = get_base_url('generate')
    data = create_generate_request_data(CONFIG['test_cases']['generate']['query'], stream=False)

    # Send request
    response = make_request(url, data)

    # Print response
    if OutputControl.is_verbose():
        print('\n=== Non-streaming generate response ===')
    response_json = response.json()

    # Print response content
    print(json.dumps(response_json, ensure_ascii=False, indent=2))


@pytest.mark.integration
@pytest.mark.requires_api
def test_stream_generate() -> None:
    """Test streaming call to /api/generate endpoint"""
    url = get_base_url('generate')
    data = create_generate_request_data(CONFIG['test_cases']['generate']['query'], stream=True)

    # Send request and get streaming response
    response = make_request(url, data, stream=True)

    if OutputControl.is_verbose():
        print('\n=== Streaming generate response ===')
    output_buffer = []
    try:
        for line in response.iter_lines():
            if line:  # Skip empty lines
                try:
                    # Decode and parse JSON
                    data = json.loads(line.decode('utf-8'))
                    if data.get('done', True):  # If it's the completion marker
                        if 'total_duration' in data:  # Final performance statistics message
                            break
                    else:  # Normal content message
                        content = data.get('response', '')
                        if content:  # Only collect non-empty content
                            output_buffer.append(content)
                            print(content, end='', flush=True)  # Print content in real time
                except json.JSONDecodeError:
                    print('Error decoding JSON from response line')
    finally:
        response.close()  # Ensure the response connection is closed

    # Print a trailing newline
    print()


@pytest.mark.integration
@pytest.mark.requires_api
def test_generate_with_system() -> None:
    """Test generate with a system prompt"""
    url = get_base_url('generate')
    data = create_generate_request_data(
        CONFIG['test_cases']['generate']['query'],
        system='你是一个知识渊博的助手',  # "You are a knowledgeable assistant"
        stream=False,
    )

    # Send request
    response = make_request(url, data)

    # Print response
    if OutputControl.is_verbose():
        print('\n=== Generate with system prompt response ===')
    response_json = response.json()

    # Print response content
    print_json_response(
        {
            'model': response_json['model'],
            'response': response_json['response'],
            'done': response_json['done'],
        },
        'Response content',
    )


@pytest.mark.integration
@pytest.mark.requires_api
def test_generate_error_handling() -> None:
    """Test error handling for the generate endpoint"""
    url = get_base_url('generate')

    # Test empty prompt
    if OutputControl.is_verbose():
        print('\n=== Testing empty prompt ===')
    data = create_generate_request_data('', stream=False)
    response = make_request(url, data, check_status=False)
    print(f'Status code: {response.status_code}')
    print_json_response(response.json(), 'Error message')

    # Test invalid options
    if OutputControl.is_verbose():
        print('\n=== Testing invalid options ===')
    data = create_generate_request_data(
        CONFIG['test_cases']['basic']['query'],
        options={'invalid_option': 'value'},
        stream=False,
    )
    response = make_request(url, data, check_status=False)
    print(f'Status code: {response.status_code}')
    print_json_response(response.json(), 'Error message')


@pytest.mark.integration
@pytest.mark.requires_api
def test_generate_concurrent() -> None:
    """Test concurrent generate requests"""
    import asyncio
    from contextlib import asynccontextmanager

    import aiohttp

    @asynccontextmanager
    async def get_session():
        async with aiohttp.ClientSession() as session:
            yield session

    # Local async variant; the distinct name avoids shadowing the module-level make_request
    async def make_async_request(session, prompt: str, request_id: int):
        url = get_base_url('generate')
        data = create_generate_request_data(prompt, stream=False)
        try:
            async with session.post(url, json=data) as response:
                if response.status != 200:
                    error_msg = f'Request {request_id} failed with status {response.status}'
                    if OutputControl.is_verbose():
                        print(f'\n{error_msg}')
                    raise McpError(ErrorCode.InternalError, error_msg)
                result = await response.json()
                if 'error' in result:
                    error_msg = f'Request {request_id} returned error: {result["error"]}'
                    if OutputControl.is_verbose():
                        print(f'\n{error_msg}')
                    raise McpError(ErrorCode.InternalError, error_msg)
                return result
        except Exception as e:
            error_msg = f'Request {request_id} failed: {e!s}'
            if OutputControl.is_verbose():
                print(f'\n{error_msg}')
            raise McpError(ErrorCode.InternalError, error_msg) from e

    async def run_concurrent_requests():
        # Five distinct prompts ("Question one" ... "Question five")
        prompts = ['第一个问题', '第二个问题', '第三个问题', '第四个问题', '第五个问题']

        async with get_session() as session:
            tasks = [make_async_request(session, prompt, i + 1) for i, prompt in enumerate(prompts)]
            results = await asyncio.gather(*tasks, return_exceptions=True)

        success_results = []
        error_messages = []

        for i, result in enumerate(results):
            if isinstance(result, Exception):
                error_messages.append(f'Request {i + 1} failed: {result!s}')
            else:
                success_results.append((i + 1, result))

        if error_messages:
            for req_id, result in success_results:
                if OutputControl.is_verbose():
                    print(f'\nRequest {req_id} succeeded:')
                    print_json_response(result)

            error_summary = '\n'.join(error_messages)
            raise McpError(
                ErrorCode.InternalError,
                f'Some concurrent requests failed:\n{error_summary}',
            )

        return results

    if OutputControl.is_verbose():
        print('\n=== Testing concurrent generate requests ===')

    # Run concurrent requests
    try:
        results = asyncio.run(run_concurrent_requests())
        # All requests succeeded; print the results
        for i, result in enumerate(results, 1):
            print(f'\nRequest {i} result:')
            print_json_response(result)
    except McpError:
        # Error message already printed
        raise


def get_test_cases() -> dict[str, Callable[[], None]]:
    """Get all available test cases

    Returns:
        A dictionary mapping test names to test functions
    """
    return {
        'non_stream': test_non_stream_chat,
        'stream': test_stream_chat,
        'modes': test_query_modes,
        'errors': test_error_handling,
        'stream_errors': test_stream_error_handling,
        'non_stream_generate': test_non_stream_generate,
        'stream_generate': test_stream_generate,
        'generate_with_system': test_generate_with_system,
        'generate_errors': test_generate_error_handling,
        'generate_concurrent': test_generate_concurrent,
    }


def create_default_config() -> None:
    """Create a default configuration file"""
    config_path = Path('config.json')
    if not config_path.exists():
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
        print(f'Default configuration file created: {config_path}')


def parse_args() -> argparse.Namespace:
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description='LightRAG Ollama Compatibility Interface Testing',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Configuration file (config.json):
{
    "server": {
        "host": "localhost",       # Server address
        "port": 9621,              # Server port
        "model": "lightrag:latest" # Default model name
    },
    "test_cases": {
        "basic": {
            "query": "Test query"  # Chat query text
        },
        "generate": {
            "query": "Test query"  # Generate query text
        }
    }
}
""",
    )
    parser.add_argument(
        '-q',
        '--quiet',
        action='store_true',
        help='Silent mode, only display the test result summary',
    )
    parser.add_argument(
        '-a',
        '--ask',
        type=str,
        help='Specify query content, overriding the query settings in the configuration file',
    )
    parser.add_argument('--init-config', action='store_true', help='Create default configuration file')
    parser.add_argument(
        '--output',
        type=str,
        default='',
        help='Test result output file path; by default results are not written to a file',
    )
    parser.add_argument(
        '--tests',
        nargs='+',
        choices=[*list(get_test_cases().keys()), 'all'],
        default=['all'],
        help="Test cases to run, options: %(choices)s. Use 'all' to run all tests (except error tests)",
    )
    return parser.parse_args()


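# Typical invocations (the script filename is an assumption; substitute the
# actual file name):
#
#     python test_lightrag_ollama_chat.py --init-config
#     python test_lightrag_ollama_chat.py -q --tests stream modes
#     python test_lightrag_ollama_chat.py -a 'your question' --output results.json
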
if __name__ == '__main__':
    args = parse_args()

    # Set output mode
    OutputControl.set_verbose(not args.quiet)

    # If query content is specified, update the configuration
    if args.ask:
        CONFIG['test_cases']['basic']['query'] = args.ask

    # If asked to create a configuration file, do so and exit
    if args.init_config:
        create_default_config()
        exit(0)

    test_cases = get_test_cases()

    try:
        if 'all' in args.tests:
            # Run all tests except error handling tests
            if OutputControl.is_verbose():
                print('\n【Chat API Tests】')
            run_test(test_non_stream_chat, 'Non-streaming Chat Test')
            run_test(test_stream_chat, 'Streaming Chat Test')
            run_test(test_query_modes, 'Chat Query Mode Test')

            if OutputControl.is_verbose():
                print('\n【Generate API Tests】')
            run_test(test_non_stream_generate, 'Non-streaming Generate Test')
            run_test(test_stream_generate, 'Streaming Generate Test')
            run_test(test_generate_with_system, 'Generate with System Prompt Test')
            run_test(test_generate_concurrent, 'Generate Concurrent Test')
        else:
            # Run specified tests
            for test_name in args.tests:
                if OutputControl.is_verbose():
                    print(f'\n【Running Test: {test_name}】')
                run_test(test_cases[test_name], test_name)
    except Exception as e:
        print(f'\nAn error occurred: {e!s}')
    finally:
        # Print test statistics
        STATS.print_summary()
        # If an output file path is specified, export the results
        if args.output:
            STATS.export_results(args.output)