diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e7d00f4a..1c0db0d8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: Offline Unit Tests +name: Tests on: push: diff --git a/tests/conftest.py b/tests/conftest.py index 41db438d..09769fd6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,21 @@ This file provides command-line options and fixtures for test configuration. import pytest +def pytest_configure(config): + """Register custom markers for LightRAG tests.""" + config.addinivalue_line( + "markers", "offline: marks tests as offline (no external dependencies)" + ) + config.addinivalue_line( + "markers", + "integration: marks tests requiring external services (skipped by default)", + ) + config.addinivalue_line("markers", "requires_db: marks tests requiring database") + config.addinivalue_line( + "markers", "requires_api: marks tests requiring LightRAG API server" + ) + + def pytest_addoption(parser): """Add custom command-line options for LightRAG tests.""" @@ -32,6 +47,32 @@ def pytest_addoption(parser): help="Number of parallel workers for stress tests (default: 3)", ) + parser.addoption( + "--run-integration", + action="store_true", + default=False, + help="Run integration tests that require external services (database, API server, etc.)", + ) + + +def pytest_collection_modifyitems(config, items): + """Modify test collection to skip integration tests by default. + + Integration tests are skipped unless --run-integration flag is provided. + This allows running offline tests quickly without needing external services. + """ + if config.getoption("--run-integration"): + # If --run-integration is specified, run all tests + return + + skip_integration = pytest.mark.skip( + reason="Requires external services(DB/API), use --run-integration to run" + ) + + for item in items: + if "integration" in item.keywords: + item.add_marker(skip_integration) + @pytest.fixture(scope="session") def keep_test_artifacts(request): @@ -83,3 +124,20 @@ def parallel_workers(request): # Fall back to environment variable return int(os.getenv("LIGHTRAG_TEST_WORKERS", "3")) + + +@pytest.fixture(scope="session") +def run_integration_tests(request): + """ + Fixture to determine whether to run integration tests. + + Priority: CLI option > Environment variable > Default (False) + """ + import os + + # Check CLI option first + if request.config.getoption("--run-integration"): + return True + + # Fall back to environment variable + return os.getenv("LIGHTRAG_RUN_INTEGRATION", "false").lower() == "true" diff --git a/tests/test_aquery_data_endpoint.py b/tests/test_aquery_data_endpoint.py index 8845cb79..4866c779 100644 --- a/tests/test_aquery_data_endpoint.py +++ b/tests/test_aquery_data_endpoint.py @@ -9,6 +9,7 @@ Updated to handle the new data format where: - Includes backward compatibility with legacy format """ +import pytest import requests import time import json @@ -84,6 +85,8 @@ def parse_streaming_response( return references, response_chunks, errors +@pytest.mark.integration +@pytest.mark.requires_api def test_query_endpoint_references(): """Test /query endpoint references functionality""" @@ -187,6 +190,8 @@ def test_query_endpoint_references(): return True +@pytest.mark.integration +@pytest.mark.requires_api def test_query_stream_endpoint_references(): """Test /query/stream endpoint references functionality""" @@ -322,6 +327,8 @@ def test_query_stream_endpoint_references(): return True +@pytest.mark.integration +@pytest.mark.requires_api def test_references_consistency(): """Test references consistency across all endpoints""" @@ -472,6 +479,8 @@ def test_references_consistency(): return consistency_passed +@pytest.mark.integration +@pytest.mark.requires_api def test_aquery_data_endpoint(): """Test the /query/data endpoint""" @@ -654,6 +663,8 @@ def print_query_results(data: Dict[str, Any]): print("=" * 60) +@pytest.mark.integration +@pytest.mark.requires_api def compare_with_regular_query(): """Compare results between regular query and data query""" @@ -690,6 +701,8 @@ def compare_with_regular_query(): print(f" Regular query error: {str(e)}") +@pytest.mark.integration +@pytest.mark.requires_api def run_all_reference_tests(): """Run all reference-related tests""" diff --git a/tests/test_graph_storage.py b/tests/test_graph_storage.py index f707974b..d363346d 100644 --- a/tests/test_graph_storage.py +++ b/tests/test_graph_storage.py @@ -18,6 +18,7 @@ import os import sys import importlib import numpy as np +import pytest from dotenv import load_dotenv from ascii_colors import ASCIIColors @@ -128,6 +129,8 @@ async def initialize_graph_storage(): return None +@pytest.mark.integration +@pytest.mark.requires_db async def test_graph_basic(storage): """ Test basic graph database operations: @@ -237,6 +240,8 @@ async def test_graph_basic(storage): return False +@pytest.mark.integration +@pytest.mark.requires_db async def test_graph_advanced(storage): """ Test advanced graph database operations: @@ -431,6 +436,8 @@ async def test_graph_advanced(storage): return False +@pytest.mark.integration +@pytest.mark.requires_db async def test_graph_batch_operations(storage): """ Test batch operations of the graph database: @@ -769,6 +776,8 @@ async def test_graph_batch_operations(storage): return False +@pytest.mark.integration +@pytest.mark.requires_db async def test_graph_special_characters(storage): """ Test the graph database's handling of special characters: @@ -907,6 +916,8 @@ async def test_graph_special_characters(storage): return False +@pytest.mark.integration +@pytest.mark.requires_db async def test_graph_undirected_property(storage): """ Specifically test the undirected graph property of the storage: diff --git a/tests/test_lightrag_ollama_chat.py b/tests/test_lightrag_ollama_chat.py index f936a2b5..fe1cc70d 100644 --- a/tests/test_lightrag_ollama_chat.py +++ b/tests/test_lightrag_ollama_chat.py @@ -9,12 +9,11 @@ This script tests the LightRAG's Ollama compatibility interface, including: All responses use the JSON Lines format, complying with the Ollama API specification. """ +import pytest import requests import json import argparse import time -import pytest -import os from typing import Dict, Any, Optional, List, Callable from dataclasses import dataclass, asdict from datetime import datetime @@ -22,38 +21,6 @@ from pathlib import Path from enum import Enum, auto -def _check_ollama_server_available(host: str = "localhost", port: int = 9621) -> bool: - """Check if the Ollama-compatible LightRAG server is available and working. - - We test by making a simple API call to ensure the server is actually functioning, - not just that some process is listening on the port. - """ - try: - # First check if something is listening - response = requests.get(f"http://{host}:{port}/", timeout=2) - if response.status_code != 200: - return False - # Try a simple API call to ensure the server is actually working - test_data = { - "model": "lightrag:latest", - "messages": [{"role": "user", "content": "test"}], - "stream": False - } - response = requests.post(f"http://{host}:{port}/api/chat", json=test_data, timeout=5) - return response.status_code == 200 - except (requests.ConnectionError, requests.Timeout, Exception): - return False - - -# Allow override via environment variable for CI/CD -_OLLAMA_SERVER_HOST = os.getenv("OLLAMA_TEST_HOST", "localhost") -_OLLAMA_SERVER_PORT = int(os.getenv("OLLAMA_TEST_PORT", "9621")) -_SKIP_OLLAMA_TESTS = os.getenv("SKIP_OLLAMA_TESTS", "").lower() in ("1", "true", "yes") - -# Check server availability once at module level -_SERVER_AVAILABLE = not _SKIP_OLLAMA_TESTS and _check_ollama_server_available(_OLLAMA_SERVER_HOST, _OLLAMA_SERVER_PORT) - - class ErrorCode(Enum): """Error codes for MCP errors""" @@ -109,7 +76,7 @@ class OutputControl: @dataclass -class OllamaTestResult: +class TestResult: """Test result data class""" name: str @@ -123,14 +90,14 @@ class OllamaTestResult: self.timestamp = datetime.now().isoformat() -class OllamaTestStats: +class TestStats: """Test statistics""" def __init__(self): - self.results: List[OllamaTestResult] = [] + self.results: List[TestResult] = [] self.start_time = datetime.now() - def add_result(self, result: OllamaTestResult): + def add_result(self, result: TestResult): self.results.append(result) def export_results(self, path: str = "test_results.json"): @@ -307,7 +274,7 @@ def create_generate_request_data( # Global test statistics -STATS = OllamaTestStats() +STATS = TestStats() def run_test(func: Callable, name: str) -> None: @@ -320,14 +287,15 @@ def run_test(func: Callable, name: str) -> None: try: func() duration = time.time() - start_time - STATS.add_result(OllamaTestResult(name, True, duration)) + STATS.add_result(TestResult(name, True, duration)) except Exception as e: duration = time.time() - start_time - STATS.add_result(OllamaTestResult(name, False, duration, str(e))) + STATS.add_result(TestResult(name, False, duration, str(e))) raise -@pytest.mark.skipif(not _SERVER_AVAILABLE, reason="Ollama server not available at localhost:9621") +@pytest.mark.integration +@pytest.mark.requires_api def test_non_stream_chat() -> None: """Test non-streaming call to /api/chat endpoint""" url = get_base_url() @@ -352,7 +320,8 @@ def test_non_stream_chat() -> None: ) -@pytest.mark.skipif(not _SERVER_AVAILABLE, reason="Ollama server not available at localhost:9621") +@pytest.mark.integration +@pytest.mark.requires_api def test_stream_chat() -> None: """Test streaming call to /api/chat endpoint @@ -413,7 +382,8 @@ def test_stream_chat() -> None: print() -@pytest.mark.skipif(not _SERVER_AVAILABLE, reason="Ollama server not available at localhost:9621") +@pytest.mark.integration +@pytest.mark.requires_api def test_query_modes() -> None: """Test different query mode prefixes @@ -473,7 +443,8 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]: return error_data.get(error_type, error_data["empty_messages"]) -@pytest.mark.skipif(not _SERVER_AVAILABLE, reason="Ollama server not available at localhost:9621") +@pytest.mark.integration +@pytest.mark.requires_api def test_stream_error_handling() -> None: """Test error handling for streaming responses @@ -520,7 +491,8 @@ def test_stream_error_handling() -> None: response.close() -@pytest.mark.skipif(not _SERVER_AVAILABLE, reason="Ollama server not available at localhost:9621") +@pytest.mark.integration +@pytest.mark.requires_api def test_error_handling() -> None: """Test error handling for non-streaming responses @@ -568,7 +540,8 @@ def test_error_handling() -> None: print_json_response(response.json(), "Error message") -@pytest.mark.skipif(not _SERVER_AVAILABLE, reason="Ollama server not available at localhost:9621") +@pytest.mark.integration +@pytest.mark.requires_api def test_non_stream_generate() -> None: """Test non-streaming call to /api/generate endpoint""" url = get_base_url("generate") @@ -588,7 +561,8 @@ def test_non_stream_generate() -> None: print(json.dumps(response_json, ensure_ascii=False, indent=2)) -@pytest.mark.skipif(not _SERVER_AVAILABLE, reason="Ollama server not available at localhost:9621") +@pytest.mark.integration +@pytest.mark.requires_api def test_stream_generate() -> None: """Test streaming call to /api/generate endpoint""" url = get_base_url("generate") @@ -629,7 +603,8 @@ def test_stream_generate() -> None: print() -@pytest.mark.skipif(not _SERVER_AVAILABLE, reason="Ollama server not available at localhost:9621") +@pytest.mark.integration +@pytest.mark.requires_api def test_generate_with_system() -> None: """Test generate with system prompt""" url = get_base_url("generate") @@ -658,7 +633,8 @@ def test_generate_with_system() -> None: ) -@pytest.mark.skipif(not _SERVER_AVAILABLE, reason="Ollama server not available at localhost:9621") +@pytest.mark.integration +@pytest.mark.requires_api def test_generate_error_handling() -> None: """Test error handling for generate endpoint""" url = get_base_url("generate") @@ -684,7 +660,8 @@ def test_generate_error_handling() -> None: print_json_response(response.json(), "Error message") -@pytest.mark.skipif(not _SERVER_AVAILABLE, reason="Ollama server not available at localhost:9621") +@pytest.mark.integration +@pytest.mark.requires_api def test_generate_concurrent() -> None: """Test concurrent generate requests""" import asyncio diff --git a/tests/test_postgres_retry_integration.py b/tests/test_postgres_retry_integration.py index 71e5c47d..24f8db52 100644 --- a/tests/test_postgres_retry_integration.py +++ b/tests/test_postgres_retry_integration.py @@ -23,6 +23,8 @@ from lightrag.kg.postgres_impl import PostgreSQLDB load_dotenv(dotenv_path=".env", override=False) +@pytest.mark.integration +@pytest.mark.requires_db class TestPostgresRetryIntegration: """Integration tests for PostgreSQL retry mechanism with real database.""" diff --git a/tests/test_workspace_isolation.py b/tests/test_workspace_isolation.py index ef55cf56..5e5e6220 100644 --- a/tests/test_workspace_isolation.py +++ b/tests/test_workspace_isolation.py @@ -149,6 +149,7 @@ def _assert_no_timeline_overlap(timeline: List[Tuple[str, str]]) -> None: # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_pipeline_status_isolation(): """ @@ -203,6 +204,7 @@ async def test_pipeline_status_isolation(): # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_lock_mechanism(stress_test_mode, parallel_workers): """ @@ -272,6 +274,7 @@ async def test_lock_mechanism(stress_test_mode, parallel_workers): # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_backward_compatibility(): """ @@ -345,6 +348,7 @@ async def test_backward_compatibility(): # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_multi_workspace_concurrency(): """ @@ -428,6 +432,7 @@ async def test_multi_workspace_concurrency(): # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_namespace_lock_reentrance(): """ @@ -501,6 +506,7 @@ async def test_namespace_lock_reentrance(): # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_different_namespace_lock_isolation(): """ @@ -540,6 +546,7 @@ async def test_different_namespace_lock_isolation(): # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_error_handling(): """ @@ -590,6 +597,7 @@ async def test_error_handling(): # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_update_flags_workspace_isolation(): """ @@ -719,6 +727,7 @@ async def test_update_flags_workspace_isolation(): # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_empty_workspace_standardization(): """ @@ -772,6 +781,7 @@ async def test_empty_workspace_standardization(): # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_json_kv_storage_workspace_isolation(keep_test_artifacts): """ @@ -848,6 +858,9 @@ async def test_json_kv_storage_workspace_isolation(keep_test_artifacts): } ) print(" Written to storage1: entity1, entity2") + # Persist data to disk + await storage1.index_done_callback() + print(" Persisted storage1 data to disk") # Write to storage2 await storage2.upsert( @@ -863,6 +876,9 @@ async def test_json_kv_storage_workspace_isolation(keep_test_artifacts): } ) print(" Written to storage2: entity1, entity2") + # Persist data to disk + await storage2.index_done_callback() + print(" Persisted storage2 data to disk") # Test 10.3: Read data from each storage and verify isolation print("\nTest 10.3: Read data and verify isolation") @@ -940,6 +956,7 @@ async def test_json_kv_storage_workspace_isolation(keep_test_artifacts): # ============================================================================= +@pytest.mark.offline @pytest.mark.asyncio async def test_lightrag_end_to_end_workspace_isolation(keep_test_artifacts): """ diff --git a/tests/test_write_json_optimization.py b/tests/test_write_json_optimization.py index ea555c50..e0331390 100644 --- a/tests/test_write_json_optimization.py +++ b/tests/test_write_json_optimization.py @@ -11,9 +11,11 @@ This test verifies: import os import json import tempfile +import pytest from lightrag.utils import write_json, load_json, SanitizingJSONEncoder +@pytest.mark.offline class TestWriteJsonOptimization: """Test write_json optimization with two-stage approach"""