LightRAG/tests/test_multi_workspace_server.py
Clément THOMAS 62b2a71dda feat(api): add multi-workspace server support for multi-tenant deployments
Enable a single LightRAG server instance to serve multiple isolated workspaces
via HTTP header-based routing. This allows multi-tenant SaaS deployments where
each tenant's data is completely isolated.

Key features:
- Header-based workspace routing (LIGHTRAG-WORKSPACE, X-Workspace-ID fallback)
- Process-local pool of LightRAG instances with LRU eviction
- FastAPI dependency (get_rag) for workspace resolution per request
- Full backward compatibility - existing deployments work unchanged
- Strict multi-tenant mode option (LIGHTRAG_ALLOW_DEFAULT_WORKSPACE=false)
- Configurable pool size (LIGHTRAG_MAX_WORKSPACES_IN_POOL)
- Graceful shutdown with workspace finalization

Configuration:
- LIGHTRAG_DEFAULT_WORKSPACE: Default workspace (falls back to WORKSPACE)
- LIGHTRAG_ALLOW_DEFAULT_WORKSPACE: Require explicit header when false
- LIGHTRAG_MAX_WORKSPACES_IN_POOL: Max concurrent workspace instances (default: 50)

Files:
- New: lightrag/api/workspace_manager.py (core multi-workspace module)
- New: tests/test_multi_workspace_server.py (17 unit tests)
- New: render.yaml (Render deployment blueprint)
- Modified: All route files to use get_rag dependency
- Updated: README.md, env.example with documentation

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-01 12:07:22 +01:00

357 lines
13 KiB
Python

"""
Tests for multi-workspace server support.
This module tests the server-level multi-workspace functionality including:
- Workspace identifier validation
- WorkspacePool management and LRU eviction
- Header-based workspace routing
- Workspace isolation (documents, queries, graphs)
- Backward compatibility with single-workspace mode
- Strict multi-tenant mode
Tests are organized by user story to match the implementation plan.
"""
import asyncio
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from lightrag.api.workspace_manager import (
WorkspaceConfig,
WorkspacePool,
WorkspaceInstance,
validate_workspace_id,
get_workspace_from_request,
WORKSPACE_ID_PATTERN,
)
# =============================================================================
# Phase 2: Foundational - Unit Tests
# =============================================================================
class TestWorkspaceValidation:
    """T013: Unit tests for workspace identifier validation.

    The identifier lists are parametrized so that every identifier is
    collected and reported as its own test case. A failure therefore names
    the offending identifier, and one bad entry no longer hides the result
    of the entries that follow it.
    """

    @pytest.mark.parametrize(
        "workspace_id",
        [
            "tenant1",
            "tenant-a",
            "tenant_b",
            "Workspace123",
            "a",
            "A1b2C3",
            "workspace-with-dashes",
            "workspace_with_underscores",
            "a" * 64,  # Max length
        ],
    )
    def test_valid_workspace_ids(self, workspace_id):
        """Valid workspace identifiers should pass validation."""
        validate_workspace_id(workspace_id)  # Should not raise

    @pytest.mark.parametrize(
        "workspace_id",
        [
            "",  # Empty
            "-starts-with-dash",
            "_starts_with_underscore",
            "has spaces",
            "has/slashes",
            "has\\backslashes",
            "has.dots",
            "a" * 65,  # Too long
            "../path-traversal",
            "has:colons",
        ],
    )
    def test_invalid_workspace_ids(self, workspace_id):
        """Invalid workspace identifiers should raise ValueError."""
        with pytest.raises(ValueError):
            validate_workspace_id(workspace_id)

    def test_workspace_id_pattern(self):
        """Verify the regex pattern matches expected identifiers."""
        assert WORKSPACE_ID_PATTERN.match("tenant1")
        assert WORKSPACE_ID_PATTERN.match("tenant-a")
        assert WORKSPACE_ID_PATTERN.match("tenant_b")
        assert not WORKSPACE_ID_PATTERN.match("")
        assert not WORKSPACE_ID_PATTERN.match("-invalid")
        assert not WORKSPACE_ID_PATTERN.match("_invalid")
class TestWorkspacePool:
    """T014: Unit tests for WorkspacePool."""

    @pytest.fixture
    def mock_rag_factory(self):
        """Provide an async factory that builds one mock RAG per workspace."""

        async def _build(workspace_id: str):
            # The mock carries the workspace id it was built for and an
            # awaitable ``finalize_storages`` attribute.
            return MagicMock(
                workspace=workspace_id,
                finalize_storages=AsyncMock(),
            )

        return _build

    @pytest.fixture
    def config(self):
        """Provide a configuration whose pool holds at most three workspaces."""
        settings = WorkspaceConfig(
            default_workspace="default",
            allow_default_workspace=True,
            max_workspaces_in_pool=3,
        )
        return settings

    @pytest.fixture
    def pool(self, config, mock_rag_factory):
        """Provide a WorkspacePool built from the test config and mock factory."""
        return WorkspacePool(config, mock_rag_factory)

    async def test_get_creates_new_instance(self, pool):
        """The first lookup of a workspace yields a fresh instance."""
        instance = await pool.get("tenant1")

        assert instance is not None
        assert instance.workspace == "tenant1"
        assert pool.size == 1

    async def test_get_returns_cached_instance(self, pool):
        """Repeated lookups of one workspace yield the very same object."""
        first = await pool.get("tenant1")
        second = await pool.get("tenant1")

        assert first is second
        assert pool.size == 1

    async def test_lru_eviction(self, pool):
        """A full pool drops its least recently used entry for a newcomer."""
        # Fill the pool up to its configured maximum of three.
        for tenant in ("tenant1", "tenant2", "tenant3"):
            await pool.get(tenant)
        assert pool.size == 3

        # Touch tenant1 again so that tenant2 becomes the oldest entry.
        await pool.get("tenant1")

        # A fourth tenant must push out tenant2 (the LRU entry).
        await pool.get("tenant4")
        assert pool.size == 3

        assert "tenant2" not in pool._instances
        for survivor in ("tenant1", "tenant3", "tenant4"):
            assert survivor in pool._instances

    async def test_invalid_workspace_id_rejected(self, pool):
        """Malformed workspace identifiers are refused with ValueError."""
        for bad_id in ("", "-invalid"):
            with pytest.raises(ValueError):
                await pool.get(bad_id)

    async def test_finalize_all(self, pool):
        """finalize_all leaves the pool empty."""
        for tenant in ("tenant1", "tenant2"):
            await pool.get(tenant)
        assert pool.size == 2

        await pool.finalize_all()
        assert pool.size == 0
class TestGetWorkspaceFromRequest:
    """Tests for extracting the workspace identifier from request headers."""

    @staticmethod
    def _request_with(headers):
        """Build a mock request whose ``headers`` attribute is *headers*."""
        fake_request = MagicMock()
        fake_request.headers = headers
        return fake_request

    def test_primary_header(self):
        """The LIGHTRAG-WORKSPACE header is honoured as the primary source."""
        req = self._request_with({"LIGHTRAG-WORKSPACE": "tenant1"})
        assert get_workspace_from_request(req) == "tenant1"

    def test_fallback_header(self):
        """X-Workspace-ID is used when it is the only header present."""
        req = self._request_with({"X-Workspace-ID": "tenant2"})
        assert get_workspace_from_request(req) == "tenant2"

    def test_primary_takes_precedence(self):
        """With both headers set, LIGHTRAG-WORKSPACE wins over X-Workspace-ID."""
        both_headers = {
            "LIGHTRAG-WORKSPACE": "primary",
            "X-Workspace-ID": "fallback",
        }
        req = self._request_with(both_headers)
        assert get_workspace_from_request(req) == "primary"

    def test_no_header_returns_none(self):
        """Without any workspace header the result is None."""
        req = self._request_with({})
        assert get_workspace_from_request(req) is None

    def test_empty_header_returns_none(self):
        """A whitespace-only header value is treated as absent."""
        req = self._request_with({"LIGHTRAG-WORKSPACE": " "})
        assert get_workspace_from_request(req) is None
# =============================================================================
# Phase 3: User Story 1+2 - Isolation & Routing Tests
# =============================================================================
@pytest.mark.integration
class TestWorkspaceIsolation:
    """T015-T016: Placeholder tests for workspace data isolation.

    Both tests currently skip themselves; they are not yet implemented.
    """

    # Shared skip reason for the not-yet-implemented tests below.
    _SKIP_REASON = "Integration test - requires running server"

    async def test_ingest_in_workspace_a_query_from_workspace_b_returns_nothing(self):
        """Documents ingested into workspace A must be invisible from workspace B."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)

    async def test_query_from_workspace_a_returns_own_documents(self):
        """A query must return the documents of its own workspace."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)
@pytest.mark.integration
class TestWorkspaceRouting:
    """T017-T019: Placeholder tests for header-based workspace routing.

    All tests currently skip themselves; they are not yet implemented.
    """

    # Shared skip reason for the not-yet-implemented tests below.
    _SKIP_REASON = "Integration test - requires running server"

    async def test_lightrag_workspace_header_routes_correctly(self):
        """The LIGHTRAG-WORKSPACE header must route to the matching workspace."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)

    async def test_x_workspace_id_fallback_works(self):
        """The X-Workspace-ID header must be accepted as a fallback."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)

    async def test_lightrag_workspace_takes_precedence(self):
        """LIGHTRAG-WORKSPACE must win when both headers are supplied."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)
# =============================================================================
# Phase 4: User Story 3 - Backward Compatibility Tests
# =============================================================================
@pytest.mark.integration
class TestBackwardCompatibility:
    """T031-T033: Placeholder tests for backward compatibility.

    All tests currently skip themselves; they are not yet implemented.
    """

    # Shared skip reason for the not-yet-implemented tests below.
    _SKIP_REASON = "Integration test - requires running server"

    async def test_no_header_uses_workspace_env_var(self):
        """A request without workspace headers must use the WORKSPACE env var."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)

    async def test_existing_routes_unchanged(self):
        """Existing route paths must stay as they were."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)

    async def test_response_formats_unchanged(self):
        """Response formats must stay as they were."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)
# =============================================================================
# Phase 5: User Story 4 - Strict Mode Tests
# =============================================================================
@pytest.mark.integration
class TestStrictMode:
    """T038-T040: Placeholder tests for strict multi-tenant mode.

    All tests currently skip themselves; they are not yet implemented.
    """

    # Shared skip reason for the not-yet-implemented tests below.
    _SKIP_REASON = "Integration test - requires running server"

    async def test_missing_header_returns_400_when_default_disabled(self):
        """With the default workspace disabled, a missing header must yield 400."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)

    async def test_error_message_indicates_missing_header(self):
        """The error message must clearly point at the missing header."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)

    async def test_missing_header_uses_default_when_enabled(self):
        """With the default workspace enabled, a missing header must use it."""
        # TODO: Implement with actual server integration
        pytest.skip(self._SKIP_REASON)
# =============================================================================
# Phase 6: User Story 5 - Pool Management Tests
# =============================================================================
class TestPoolManagement:
    """T045-T048: Tests for workspace pool management."""

    @pytest.fixture
    def mock_rag_factory(self):
        """Create a mock RAG factory with initialization tracking.

        The returned coroutine function carries an ``init_count`` attribute,
        a dict whose ``"value"`` entry counts how many times the factory has
        been invoked. Each mock it builds records the workspace id it was
        created for and its 1-based creation order in ``init_order``.
        """
        init_count = {"value": 0}

        async def factory(workspace_id: str):
            init_count["value"] += 1
            order = init_count["value"]
            # Suspend once so that overlapping pool.get() calls can really
            # interleave. Without a suspension point the factory finishes
            # within a single event-loop step, and the concurrency test
            # below could never observe a duplicate initialization.
            await asyncio.sleep(0)
            mock_rag = MagicMock()
            mock_rag.workspace = workspace_id
            mock_rag.init_order = order
            mock_rag.finalize_storages = AsyncMock()
            return mock_rag

        factory.init_count = init_count
        return factory

    async def test_new_workspace_initializes_on_first_request(self, mock_rag_factory):
        """New workspace should initialize on first request."""
        config = WorkspaceConfig(max_workspaces_in_pool=5)
        pool = WorkspacePool(config, mock_rag_factory)

        rag = await pool.get("new-workspace")

        assert rag.workspace == "new-workspace"
        assert mock_rag_factory.init_count["value"] == 1

    async def test_lru_eviction_when_pool_full(self, mock_rag_factory):
        """LRU workspace should be evicted when pool is full."""
        config = WorkspaceConfig(max_workspaces_in_pool=2)
        pool = WorkspacePool(config, mock_rag_factory)

        await pool.get("workspace1")
        await pool.get("workspace2")
        assert pool.size == 2

        await pool.get("workspace3")
        assert pool.size == 2
        assert "workspace1" not in pool._instances
        # The two most recently requested workspaces must be the survivors.
        assert "workspace2" in pool._instances
        assert "workspace3" in pool._instances

    async def test_concurrent_requests_share_initialization(self, mock_rag_factory):
        """Concurrent requests for same workspace should share initialization."""
        config = WorkspaceConfig(max_workspaces_in_pool=5)
        pool = WorkspacePool(config, mock_rag_factory)

        # Start multiple concurrent requests; the factory yields to the
        # event loop, so these calls genuinely overlap.
        results = await asyncio.gather(
            pool.get("shared-workspace"),
            pool.get("shared-workspace"),
            pool.get("shared-workspace"),
        )

        # All should return the same instance
        assert results[0] is results[1] is results[2]
        # Only one initialization should have occurred
        assert mock_rag_factory.init_count["value"] == 1

    async def test_max_workspaces_config_respected(self, mock_rag_factory):
        """Pool should respect max workspaces configuration."""
        config = WorkspaceConfig(max_workspaces_in_pool=3)
        pool = WorkspacePool(config, mock_rag_factory)

        for i in range(5):
            await pool.get(f"workspace{i}")

        assert pool.size == 3
        assert pool.max_size == 3