ragflow/test/unit_test/nlp/test_hierarchical_retrieval.py
hsparks.codes d104f59e29 feat: Implement hierarchical retrieval architecture (#11610)
This PR implements the complete three-tier hierarchical retrieval architecture
as specified in issue #11610, enabling production-grade RAG capabilities.

## Tier 1: Knowledge Base Routing
- Auto-route queries to relevant knowledge bases
- Per-KB retrieval parameters (KBRetrievalParams dataclass)
- Rule-based routing with keyword overlap scoring
- LLM-based routing with fallback to rule-based
- Configurable routing methods: auto, rule_based, llm_based, all

## Tier 2: Document Filtering
- Document-level metadata filtering within selected KBs
- Configurable metadata fields for filtering
- LLM-generated filter conditions
- Metadata similarity matching (fuzzy matching)
- Enhanced metadata generation for documents

## Tier 3: Chunk Refinement
- Parent-child chunking with summary mapping
- Custom prompts for keyword extraction
- LLM-based question generation for chunks
- Integration with existing retrieval pipeline

## Metadata Management (Batch CRUD)
- MetadataService with batch operations:
  - batch_get_metadata
  - batch_update_metadata
  - batch_delete_metadata_fields
  - batch_set_metadata_field
  - get_metadata_schema
  - search_by_metadata
  - get_metadata_statistics
  - copy_metadata
- REST API endpoints in metadata_app.py

## Integration
- HierarchicalConfig dataclass for configuration
- Integrated into Dealer class (search.py)
- Wired into agent retrieval tool
- Non-breaking: disabled by default

## Tests
- 48 unit tests covering all components
- Tests for config, routing, filtering, and metadata operations
2025-12-09 07:32:00 +01:00

699 lines
24 KiB
Python

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Unit tests for hierarchical retrieval integration in search.py.
"""
import pytest
from unittest.mock import Mock
from rag.nlp.search import (
HierarchicalConfig,
HierarchicalResult,
KBRetrievalParams,
Dealer,
index_name,
)
class TestHierarchicalConfig:
"""Test HierarchicalConfig dataclass."""
def test_default_values(self):
"""Test default configuration values."""
config = HierarchicalConfig()
assert config.enabled is False
assert config.enable_kb_routing is True
assert config.kb_routing_threshold == 0.3
assert config.kb_top_k == 3
assert config.enable_doc_filtering is True
assert config.doc_top_k == 100
assert config.chunk_top_k == 10
def test_custom_values(self):
"""Test custom configuration values."""
config = HierarchicalConfig(
enabled=True,
kb_routing_threshold=0.5,
kb_top_k=5,
doc_top_k=50,
)
assert config.enabled is True
assert config.kb_routing_threshold == 0.5
assert config.kb_top_k == 5
assert config.doc_top_k == 50
class TestHierarchicalResult:
"""Test HierarchicalResult dataclass."""
def test_default_values(self):
"""Test default result values."""
result = HierarchicalResult()
assert result.selected_kb_ids == []
assert result.filtered_doc_ids == []
assert result.tier1_time_ms == 0.0
assert result.tier2_time_ms == 0.0
assert result.tier3_time_ms == 0.0
assert result.total_time_ms == 0.0
def test_populated_result(self):
"""Test result with data."""
result = HierarchicalResult(
selected_kb_ids=["kb1", "kb2"],
filtered_doc_ids=["doc1", "doc2", "doc3"],
tier1_time_ms=10.5,
tier2_time_ms=25.3,
tier3_time_ms=100.2,
total_time_ms=136.0,
)
assert len(result.selected_kb_ids) == 2
assert len(result.filtered_doc_ids) == 3
assert result.total_time_ms == 136.0
class TestTier1KBRouting:
"""Test Tier 1: KB Routing logic."""
@pytest.fixture
def mock_dealer(self):
"""Create a mock Dealer instance."""
mock_datastore = Mock()
dealer = Dealer(mock_datastore)
return dealer
def test_routing_disabled(self, mock_dealer):
"""Test that routing returns all KBs when disabled."""
config = HierarchicalConfig(enable_kb_routing=False)
kb_ids = ["kb1", "kb2", "kb3"]
result = mock_dealer._tier1_kb_routing("test query", kb_ids, None, config)
assert result == kb_ids
def test_routing_no_kb_infos(self, mock_dealer):
"""Test that routing returns all KBs when no KB info provided."""
config = HierarchicalConfig(enable_kb_routing=True)
kb_ids = ["kb1", "kb2", "kb3"]
result = mock_dealer._tier1_kb_routing("test query", kb_ids, None, config)
assert result == kb_ids
def test_routing_few_kbs(self, mock_dealer):
"""Test that routing returns all KBs when count <= kb_top_k."""
config = HierarchicalConfig(enable_kb_routing=True, kb_top_k=5)
kb_ids = ["kb1", "kb2", "kb3"]
kb_infos = [
{"id": "kb1", "name": "Finance KB", "description": "Financial documents"},
{"id": "kb2", "name": "HR KB", "description": "Human resources"},
{"id": "kb3", "name": "Tech KB", "description": "Technical docs"},
]
result = mock_dealer._tier1_kb_routing("test query", kb_ids, kb_infos, config)
assert result == kb_ids
def test_routing_selects_relevant_kbs(self, mock_dealer):
"""Test that routing selects KBs based on keyword overlap."""
config = HierarchicalConfig(
enable_kb_routing=True,
kb_top_k=2,
kb_routing_threshold=0.1
)
kb_ids = ["kb1", "kb2", "kb3", "kb4"]
kb_infos = [
{"id": "kb1", "name": "Finance Reports", "description": "Financial analysis and reports"},
{"id": "kb2", "name": "HR Policies", "description": "Human resources policies"},
{"id": "kb3", "name": "Technical Documentation", "description": "Engineering docs"},
{"id": "kb4", "name": "Financial Statements", "description": "Quarterly financial data"},
]
# Query about finance should select finance-related KBs
result = mock_dealer._tier1_kb_routing("financial report analysis", kb_ids, kb_infos, config)
# Should select at most kb_top_k KBs
assert len(result) <= config.kb_top_k
# Should include finance-related KBs
assert any(kb in result for kb in ["kb1", "kb4"])
class TestTier2DocumentFiltering:
"""Test Tier 2: Document Filtering logic."""
@pytest.fixture
def mock_dealer(self):
"""Create a mock Dealer instance."""
mock_datastore = Mock()
dealer = Dealer(mock_datastore)
return dealer
def test_filtering_disabled(self, mock_dealer):
"""Test that filtering returns empty when disabled."""
config = HierarchicalConfig(enable_doc_filtering=False)
result = mock_dealer._tier2_document_filtering(
"test query", ["tenant1"], ["kb1"], None, config
)
assert result == []
def test_filtering_with_existing_doc_ids(self, mock_dealer):
"""Test that filtering limits existing doc_ids."""
config = HierarchicalConfig(enable_doc_filtering=True, doc_top_k=2)
doc_ids = ["doc1", "doc2", "doc3", "doc4"]
result = mock_dealer._tier2_document_filtering(
"test query", ["tenant1"], ["kb1"], doc_ids, config
)
assert len(result) == 2
assert result == ["doc1", "doc2"]
class TestHierarchicalRetrieval:
"""Test the full hierarchical retrieval flow."""
@pytest.fixture
def mock_dealer(self):
"""Create a mock Dealer instance with mocked methods."""
mock_datastore = Mock()
dealer = Dealer(mock_datastore)
# Mock the retrieval method
dealer.retrieval = Mock(return_value={
"total": 5,
"chunks": [
{"chunk_id": "c1", "content_with_weight": "test content 1"},
{"chunk_id": "c2", "content_with_weight": "test content 2"},
],
"doc_aggs": [],
})
return dealer
def test_hierarchical_retrieval_basic(self, mock_dealer):
"""Test basic hierarchical retrieval flow."""
config = HierarchicalConfig(
enabled=True,
enable_kb_routing=False, # Skip routing for this test
enable_doc_filtering=False, # Skip filtering for this test
)
result = mock_dealer.hierarchical_retrieval(
question="test query",
embd_mdl=Mock(),
tenant_ids=["tenant1"],
kb_ids=["kb1", "kb2"],
hierarchical_config=config,
)
# Should have chunks from retrieval
assert "chunks" in result
assert len(result["chunks"]) == 2
# Should have hierarchical metadata
assert "hierarchical_metadata" in result
metadata = result["hierarchical_metadata"]
assert "tier1_time_ms" in metadata
assert "tier2_time_ms" in metadata
assert "tier3_time_ms" in metadata
assert "total_time_ms" in metadata
def test_hierarchical_retrieval_with_kb_infos(self, mock_dealer):
"""Test hierarchical retrieval with KB information."""
config = HierarchicalConfig(
enabled=True,
enable_kb_routing=True,
kb_top_k=2,
)
kb_infos = [
{"id": "kb1", "name": "Finance", "description": "Financial docs"},
{"id": "kb2", "name": "HR", "description": "HR policies"},
]
result = mock_dealer.hierarchical_retrieval(
question="financial report",
embd_mdl=Mock(),
tenant_ids=["tenant1"],
kb_ids=["kb1", "kb2"],
kb_infos=kb_infos,
hierarchical_config=config,
)
assert "hierarchical_metadata" in result
assert "selected_kb_ids" in result["hierarchical_metadata"]
def test_hierarchical_retrieval_empty_query(self, mock_dealer):
"""Test hierarchical retrieval with empty query."""
# Mock retrieval to return empty for empty query
mock_dealer.retrieval = Mock(return_value={
"total": 0,
"chunks": [],
"doc_aggs": [],
})
config = HierarchicalConfig(enabled=True)
result = mock_dealer.hierarchical_retrieval(
question="",
embd_mdl=Mock(),
tenant_ids=["tenant1"],
kb_ids=["kb1"],
hierarchical_config=config,
)
assert result["total"] == 0
assert result["chunks"] == []
class TestIndexName:
"""Test index_name utility function."""
def test_index_name_format(self):
"""Test index name formatting."""
assert index_name("user123") == "ragflow_user123"
assert index_name("tenant_abc") == "ragflow_tenant_abc"
class TestHierarchicalConfigAdvanced:
"""Test advanced HierarchicalConfig options."""
def test_all_config_options(self):
"""Test all configuration options."""
config = HierarchicalConfig(
enabled=True,
enable_kb_routing=True,
kb_routing_method="llm_based",
kb_routing_threshold=0.5,
kb_top_k=5,
enable_doc_filtering=True,
doc_top_k=50,
metadata_fields=["department", "doc_type"],
enable_metadata_similarity=True,
metadata_similarity_threshold=0.8,
use_llm_metadata_filter=True,
chunk_top_k=20,
enable_parent_child=True,
use_summary_mapping=True,
keyword_extraction_prompt="Extract important technical terms",
question_generation_prompt="Generate questions about the content",
)
assert config.kb_routing_method == "llm_based"
assert config.metadata_fields == ["department", "doc_type"]
assert config.enable_metadata_similarity is True
assert config.use_llm_metadata_filter is True
assert config.enable_parent_child is True
assert config.use_summary_mapping is True
assert config.keyword_extraction_prompt is not None
class TestLLMKBRouting:
"""Test LLM-based KB routing."""
@pytest.fixture
def mock_dealer(self):
"""Create a mock Dealer instance."""
mock_datastore = Mock()
dealer = Dealer(mock_datastore)
return dealer
def test_llm_routing_fallback_no_model(self, mock_dealer):
"""Test LLM routing falls back when no model provided."""
config = HierarchicalConfig(
enable_kb_routing=True,
kb_routing_method="llm_based",
kb_top_k=2
)
kb_ids = ["kb1", "kb2", "kb3", "kb4"]
kb_infos = [
{"id": "kb1", "name": "Finance", "description": "Financial docs"},
{"id": "kb2", "name": "HR", "description": "HR policies"},
{"id": "kb3", "name": "Tech", "description": "Technical docs"},
{"id": "kb4", "name": "Legal", "description": "Legal documents"},
]
# Without chat_mdl, should fall back to rule-based
result = mock_dealer._tier1_kb_routing(
"financial report", kb_ids, kb_infos, config, chat_mdl=None
)
# Should still return results (from rule-based fallback)
assert len(result) <= config.kb_top_k
def test_routing_method_all(self, mock_dealer):
"""Test 'all' routing method returns all KBs."""
config = HierarchicalConfig(
enable_kb_routing=True,
kb_routing_method="all",
kb_top_k=2
)
kb_ids = ["kb1", "kb2", "kb3", "kb4"]
kb_infos = [
{"id": "kb1", "name": "Finance", "description": "Financial docs"},
{"id": "kb2", "name": "HR", "description": "HR policies"},
]
result = mock_dealer._tier1_kb_routing(
"any query", kb_ids, kb_infos, config
)
assert result == kb_ids
class TestMetadataSimilarityFilter:
"""Test metadata similarity filtering."""
@pytest.fixture
def mock_dealer(self):
"""Create a mock Dealer instance."""
mock_datastore = Mock()
dealer = Dealer(mock_datastore)
return dealer
def test_similarity_filter_no_model(self, mock_dealer):
"""Test similarity filter returns empty without embedding model."""
config = HierarchicalConfig(
enable_metadata_similarity=True,
metadata_similarity_threshold=0.7
)
doc_metadata = [
{"id": "doc1", "name": "Finance Report", "summary": "Q1 financial analysis"},
{"id": "doc2", "name": "HR Policy", "summary": "Employee guidelines"},
]
result = mock_dealer._metadata_similarity_filter(
"financial analysis", doc_metadata, config, embd_mdl=None
)
assert result == []
def test_similarity_filter_no_metadata(self, mock_dealer):
"""Test similarity filter returns empty without metadata."""
config = HierarchicalConfig(enable_metadata_similarity=True)
mock_embd = Mock()
result = mock_dealer._metadata_similarity_filter(
"test query", [], config, embd_mdl=mock_embd
)
assert result == []
class TestParentChildRetrieval:
"""Test parent-child chunking with summary mapping."""
@pytest.fixture
def mock_dealer(self):
"""Create a mock Dealer instance with mocked methods."""
mock_datastore = Mock()
dealer = Dealer(mock_datastore)
# Mock the search method
mock_search_result = Mock()
mock_search_result.ids = ["chunk1", "chunk2"]
mock_search_result.field = {
"chunk1": {
"content_ltks": "parent content",
"content_with_weight": "parent content",
"doc_id": "doc1",
"docnm_kwd": "test.pdf",
"kb_id": "kb1",
"mom_id": "",
},
"chunk2": {
"content_ltks": "child content",
"content_with_weight": "child content",
"doc_id": "doc1",
"docnm_kwd": "test.pdf",
"kb_id": "kb1",
"mom_id": "chunk1",
},
}
mock_search_result.total = 2
mock_search_result.query_vector = [0.1] * 768
mock_search_result.highlight = {}
dealer.search = Mock(return_value=mock_search_result)
dealer.rerank = Mock(return_value=([0.9, 0.8], [0.5, 0.4], [0.9, 0.8]))
return dealer
def test_summary_mapping_config(self):
"""Test summary mapping configuration."""
config = HierarchicalConfig(
enable_parent_child=True,
use_summary_mapping=True,
)
assert config.enable_parent_child is True
assert config.use_summary_mapping is True
class TestCustomKeywordExtraction:
"""Test customizable keyword extraction."""
@pytest.fixture
def mock_dealer(self):
"""Create a mock Dealer instance."""
mock_datastore = Mock()
dealer = Dealer(mock_datastore)
return dealer
def test_keyword_extraction_important(self, mock_dealer):
"""Test keyword extraction with 'important' prompt."""
chunks = [
{
"content_with_weight": "The API_KEY and DatabaseConnection are critical components.",
"important_kwd": ["existing_term"], # Pre-existing keyword
}
]
result = mock_dealer._apply_custom_keyword_extraction(
chunks, "Extract important technical terms", None
)
# Should return the chunk (possibly with extracted keywords)
assert len(result) == 1
# Should preserve existing keywords at minimum
assert "existing_term" in result[0].get("important_kwd", [])
def test_keyword_extraction_question(self, mock_dealer):
"""Test keyword extraction with 'question' prompt."""
chunks = [
{
"content_with_weight": "This section explains the authentication flow for the system. Users must provide valid credentials.",
"important_kwd": [],
}
]
result = mock_dealer._apply_custom_keyword_extraction(
chunks, "Generate questions about the content", None
)
assert len(result) == 1
# Should have a question hint
assert "question_hint" in result[0]
def test_keyword_extraction_empty_content(self, mock_dealer):
"""Test keyword extraction with empty content."""
chunks = [{"content_with_weight": "", "important_kwd": []}]
result = mock_dealer._apply_custom_keyword_extraction(
chunks, "Extract important terms", None
)
assert len(result) == 1
class TestKBRetrievalParams:
"""Test per-KB retrieval parameters."""
def test_default_params(self):
"""Test default KB params."""
params = KBRetrievalParams(kb_id="kb1")
assert params.kb_id == "kb1"
assert params.vector_similarity_weight == 0.7
assert params.similarity_threshold == 0.2
assert params.top_k == 1024
assert params.rerank_enabled is True
def test_custom_params(self):
"""Test custom KB params."""
params = KBRetrievalParams(
kb_id="finance_kb",
vector_similarity_weight=0.9,
similarity_threshold=0.3,
top_k=500,
rerank_enabled=False
)
assert params.kb_id == "finance_kb"
assert params.vector_similarity_weight == 0.9
assert params.similarity_threshold == 0.3
assert params.top_k == 500
assert params.rerank_enabled is False
def test_kb_params_in_config(self):
"""Test KB params integration in HierarchicalConfig."""
kb_params = {
"kb1": KBRetrievalParams(kb_id="kb1", vector_similarity_weight=0.8),
"kb2": KBRetrievalParams(kb_id="kb2", similarity_threshold=0.4),
}
config = HierarchicalConfig(
enabled=True,
kb_params=kb_params
)
assert len(config.kb_params) == 2
assert config.kb_params["kb1"].vector_similarity_weight == 0.8
assert config.kb_params["kb2"].similarity_threshold == 0.4
class TestLLMQuestionGeneration:
"""Test LLM-based question generation."""
@pytest.fixture
def mock_dealer(self):
"""Create a mock Dealer instance."""
mock_datastore = Mock()
dealer = Dealer(mock_datastore)
return dealer
def test_question_generation_no_model(self, mock_dealer):
"""Test question generation returns chunks unchanged without model."""
chunks = [{"content_with_weight": "Test content", "important_kwd": []}]
result = mock_dealer._apply_llm_question_generation(chunks, None, None)
assert result == chunks
assert "generated_questions" not in result[0]
def test_question_generation_with_mock_model(self, mock_dealer):
"""Test question generation with mock LLM."""
mock_chat = Mock()
mock_chat.chat = Mock(return_value="What is the main topic?\nHow does this work?")
chunks = [{
"content_with_weight": "This is a detailed explanation of the authentication system. It uses OAuth2 for secure access.",
"important_kwd": []
}]
result = mock_dealer._apply_llm_question_generation(chunks, None, mock_chat)
assert len(result) == 1
assert "generated_questions" in result[0]
assert len(result[0]["generated_questions"]) == 2
class TestDocumentMetadataGeneration:
"""Test document metadata generation."""
@pytest.fixture
def mock_dealer(self):
"""Create a mock Dealer instance."""
mock_datastore = Mock()
dealer = Dealer(mock_datastore)
return dealer
def test_metadata_generation_no_model(self, mock_dealer):
"""Test metadata generation returns empty without model."""
result = mock_dealer.generate_document_metadata("doc1", "content", None, None)
assert result["doc_id"] == "doc1"
assert result["generated"] is True
assert result["summary"] == ""
assert result["topics"] == []
def test_metadata_generation_with_mock_model(self, mock_dealer):
"""Test metadata generation with mock LLM."""
mock_chat = Mock()
mock_chat.chat = Mock(return_value="""SUMMARY: This document explains authentication.
TOPICS: security, OAuth, authentication
QUESTIONS:
- What is OAuth?
- How does authentication work?
CATEGORY: Technical Documentation""")
result = mock_dealer.generate_document_metadata(
"doc1",
"This is content about authentication and security.",
mock_chat,
None
)
assert result["doc_id"] == "doc1"
assert "authentication" in result["summary"]
assert len(result["topics"]) == 3
assert "security" in result["topics"]
assert len(result["suggested_questions"]) == 2
assert result["category"] == "Technical Documentation"
class TestFullHierarchicalConfig:
"""Test complete hierarchical configuration."""
def test_all_features_enabled(self):
"""Test config with all features enabled."""
config = HierarchicalConfig(
enabled=True,
# Tier 1
enable_kb_routing=True,
kb_routing_method="llm_based",
kb_routing_threshold=0.4,
kb_top_k=5,
kb_params={
"kb1": KBRetrievalParams(kb_id="kb1", vector_similarity_weight=0.9)
},
# Tier 2
enable_doc_filtering=True,
doc_top_k=50,
metadata_fields=["department", "author"],
enable_metadata_similarity=True,
metadata_similarity_threshold=0.8,
use_llm_metadata_filter=True,
# Tier 3
chunk_top_k=20,
enable_parent_child=True,
use_summary_mapping=True,
# Prompts
keyword_extraction_prompt="Extract domain-specific terms",
question_generation_prompt="Generate FAQ questions",
use_llm_question_generation=True,
)
assert config.enabled is True
assert config.kb_routing_method == "llm_based"
assert len(config.kb_params) == 1
assert config.enable_metadata_similarity is True
assert config.use_llm_metadata_filter is True
assert config.enable_parent_child is True
assert config.use_summary_mapping is True
assert config.use_llm_question_generation is True
if __name__ == "__main__":
pytest.main([__file__, "-v"])