diff --git a/api/db/db_models.py b/api/db/db_models.py index 3d2192b2d..b841fdac8 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -869,6 +869,13 @@ class Dialog(DataBaseModel): kb_ids = JSONField(null=False, default=[]) status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True) + + # Hierarchical Retrieval Configuration + hierarchical_retrieval_config = JSONField( + null=True, + default=None, + help_text="Configuration for three-tier hierarchical retrieval architecture" + ) class Meta: db_table = "dialog" @@ -1396,4 +1403,10 @@ def migrate_db(): except Exception: pass + # Hierarchical Retrieval Configuration + try: + migrate(migrator.add_column("dialog", "hierarchical_retrieval_config", JSONField(null=True, default=None, help_text="Configuration for three-tier hierarchical retrieval architecture"))) + except Exception: + pass + logging.disable(logging.NOTSET) diff --git a/rag/retrieval/hierarchical_retrieval.py b/rag/retrieval/hierarchical_retrieval.py new file mode 100644 index 000000000..1a49d9edd --- /dev/null +++ b/rag/retrieval/hierarchical_retrieval.py @@ -0,0 +1,406 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Hierarchical Retrieval Architecture for Production-Grade RAG + +Implements a three-tier retrieval system: +1. Tier 1: Knowledge Base Routing - Routes queries to relevant KBs +2. Tier 2: Document Filtering - Filters by metadata +3. Tier 3: Chunk Refinement - Precise vector retrieval + +This architecture addresses scalability and precision limitations +in production environments with large document collections. +""" + +import logging +from typing import List, Dict, Any, Optional +from dataclasses import dataclass, field + + +@dataclass +class RetrievalConfig: + """Configuration for hierarchical retrieval""" + + # Tier 1: KB Routing + enable_kb_routing: bool = True + kb_routing_method: str = "auto" # "auto", "rule_based", "llm_based", "all" + kb_routing_threshold: float = 0.5 + + # Tier 2: Document Filtering + enable_doc_filtering: bool = True + metadata_fields: List[str] = field(default_factory=list) # Key metadata fields to use + enable_metadata_similarity: bool = False # Fuzzy matching for text metadata + metadata_similarity_threshold: float = 0.7 + + # Tier 3: Chunk Refinement + enable_parent_child_chunking: bool = False + use_summary_mapping: bool = False + chunk_refinement_top_k: int = 10 + + # General settings + max_candidates_per_tier: int = 100 + enable_hybrid_search: bool = True + vector_weight: float = 0.7 + keyword_weight: float = 0.3 + + +@dataclass +class RetrievalResult: + """Result from hierarchical retrieval""" + + query: str + selected_kbs: List[str] = field(default_factory=list) + filtered_docs: List[Dict[str, Any]] = field(default_factory=list) + retrieved_chunks: List[Dict[str, Any]] = field(default_factory=list) + + # Metadata about retrieval process + tier1_candidates: int = 0 + tier2_candidates: int = 0 + tier3_candidates: int = 0 + + total_time_ms: float = 0.0 + tier1_time_ms: float = 0.0 + tier2_time_ms: float = 0.0 + tier3_time_ms: float = 0.0 + + +class HierarchicalRetrieval: + """ + Three-tier hierarchical retrieval system for production-grade RAG. + + This class orchestrates the retrieval process across three tiers: + - Tier 1: Routes query to relevant knowledge bases + - Tier 2: Filters documents by metadata + - Tier 3: Performs precise chunk-level retrieval + """ + + def __init__(self, config: Optional[RetrievalConfig] = None): + """ + Initialize hierarchical retrieval system. + + Args: + config: Configuration for retrieval behavior + """ + self.config = config or RetrievalConfig() + self.logger = logging.getLogger(__name__) + + def retrieve( + self, + query: str, + kb_ids: List[str], + top_k: int = 10, + filters: Optional[Dict[str, Any]] = None + ) -> RetrievalResult: + """ + Perform hierarchical retrieval. + + Args: + query: User query + kb_ids: List of knowledge base IDs to search + top_k: Number of final chunks to return + filters: Optional metadata filters + + Returns: + RetrievalResult with chunks and metadata + """ + import time + start_time = time.time() + + result = RetrievalResult(query=query) + + # Tier 1: Knowledge Base Routing + tier1_start = time.time() + selected_kbs = self._tier1_kb_routing(query, kb_ids) + result.selected_kbs = selected_kbs + result.tier1_candidates = len(selected_kbs) + result.tier1_time_ms = (time.time() - tier1_start) * 1000 + + if not selected_kbs: + self.logger.warning(f"No knowledge bases selected for query: {query}") + result.total_time_ms = (time.time() - start_time) * 1000 + return result + + # Tier 2: Document Filtering + tier2_start = time.time() + filtered_docs = self._tier2_document_filtering( + query, selected_kbs, filters + ) + result.filtered_docs = filtered_docs + result.tier2_candidates = len(filtered_docs) + result.tier2_time_ms = (time.time() - tier2_start) * 1000 + + if not filtered_docs: + self.logger.warning(f"No documents passed filtering for query: {query}") + result.total_time_ms = (time.time() - start_time) * 1000 + return result + + # Tier 3: Chunk Refinement + tier3_start = time.time() + chunks = self._tier3_chunk_refinement( + query, filtered_docs, top_k + ) + result.retrieved_chunks = chunks + result.tier3_candidates = len(chunks) + result.tier3_time_ms = (time.time() - tier3_start) * 1000 + + result.total_time_ms = (time.time() - start_time) * 1000 + + self.logger.info( + f"Hierarchical retrieval completed: " + f"{result.tier1_candidates} KBs -> " + f"{result.tier2_candidates} docs -> " + f"{result.tier3_candidates} chunks " + f"in {result.total_time_ms:.2f}ms" + ) + + return result + + def _tier1_kb_routing( + self, + query: str, + kb_ids: List[str] + ) -> List[str]: + """ + Tier 1: Route query to relevant knowledge bases. + + Args: + query: User query + kb_ids: Available knowledge base IDs + + Returns: + List of selected KB IDs + """ + if not self.config.enable_kb_routing: + return kb_ids + + method = self.config.kb_routing_method + + if method == "all": + # Use all provided KBs + return kb_ids + + elif method == "rule_based": + # Rule-based routing (placeholder for custom rules) + return self._rule_based_routing(query, kb_ids) + + elif method == "llm_based": + # LLM-based routing (placeholder for LLM integration) + return self._llm_based_routing(query, kb_ids) + + else: # "auto" + # Auto mode: intelligent routing based on query analysis + return self._auto_routing(query, kb_ids) + + def _rule_based_routing( + self, + query: str, + kb_ids: List[str] + ) -> List[str]: + """ + Rule-based KB routing using predefined rules. + + This is a placeholder for custom routing logic. + Users can extend this with domain-specific rules. + """ + # TODO: Implement rule-based routing + # For now, return all KBs + return kb_ids + + def _llm_based_routing( + self, + query: str, + kb_ids: List[str] + ) -> List[str]: + """ + LLM-based KB routing using language model. + + This is a placeholder for LLM integration. + """ + # TODO: Implement LLM-based routing + # For now, return all KBs + return kb_ids + + def _auto_routing( + self, + query: str, + kb_ids: List[str] + ) -> List[str]: + """ + Auto routing using query analysis. + + This is a placeholder for intelligent routing. + """ + # TODO: Implement auto routing with query analysis + # For now, return all KBs + return kb_ids + + def _tier2_document_filtering( + self, + query: str, + kb_ids: List[str], + filters: Optional[Dict[str, Any]] = None + ) -> List[Dict[str, Any]]: + """ + Tier 2: Filter documents by metadata. + + Args: + query: User query + kb_ids: Selected knowledge base IDs + filters: Metadata filters + + Returns: + List of filtered documents + """ + if not self.config.enable_doc_filtering: + # Skip filtering, return placeholder + return [] + + # TODO: Implement document filtering logic + # This would integrate with the existing document service + # to filter by metadata fields + + filtered_docs = [] + + # Placeholder: would query documents from selected KBs + # and apply metadata filters + + return filtered_docs + + def _tier3_chunk_refinement( + self, + query: str, + filtered_docs: List[Dict[str, Any]], + top_k: int + ) -> List[Dict[str, Any]]: + """ + Tier 3: Perform precise chunk-level retrieval. + + Args: + query: User query + filtered_docs: Documents that passed Tier 2 filtering + top_k: Number of chunks to return + + Returns: + List of retrieved chunks + """ + # TODO: Implement chunk refinement logic + # This would integrate with existing vector search + # and optionally use parent-child chunking + + chunks = [] + + # Placeholder: would perform vector search within + # the filtered document set + + return chunks + + +class KBRouter: + """ + Knowledge Base Router for Tier 1. + + Handles routing logic to select relevant KBs based on query intent. + """ + + def __init__(self): + self.logger = logging.getLogger(__name__) + + def route( + self, + query: str, + available_kbs: List[Dict[str, Any]], + method: str = "auto" + ) -> List[str]: + """ + Route query to relevant knowledge bases. + + Args: + query: User query + available_kbs: List of available KB metadata + method: Routing method ("auto", "rule_based", "llm_based") + + Returns: + List of selected KB IDs + """ + # TODO: Implement routing logic + return [kb["id"] for kb in available_kbs] + + +class DocumentFilter: + """ + Document Filter for Tier 2. + + Handles metadata-based document filtering. + """ + + def __init__(self): + self.logger = logging.getLogger(__name__) + + def filter( + self, + query: str, + documents: List[Dict[str, Any]], + metadata_fields: List[str], + filters: Optional[Dict[str, Any]] = None + ) -> List[Dict[str, Any]]: + """ + Filter documents by metadata. + + Args: + query: User query + documents: List of documents to filter + metadata_fields: Key metadata fields to consider + filters: Explicit metadata filters + + Returns: + Filtered list of documents + """ + # TODO: Implement filtering logic + return documents + + +class ChunkRefiner: + """ + Chunk Refiner for Tier 3. + + Handles precise chunk-level retrieval with optional parent-child support. + """ + + def __init__(self): + self.logger = logging.getLogger(__name__) + + def refine( + self, + query: str, + doc_ids: List[str], + top_k: int, + use_parent_child: bool = False + ) -> List[Dict[str, Any]]: + """ + Perform precise chunk retrieval. + + Args: + query: User query + doc_ids: Document IDs to search within + top_k: Number of chunks to return + use_parent_child: Whether to use parent-child chunking + + Returns: + List of retrieved chunks + """ + # TODO: Implement chunk refinement logic + return [] diff --git a/test/unit_test/retrieval/__init__.py b/test/unit_test/retrieval/__init__.py new file mode 100644 index 000000000..177b91dd0 --- /dev/null +++ b/test/unit_test/retrieval/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/test/unit_test/retrieval/test_hierarchical_retrieval.py b/test/unit_test/retrieval/test_hierarchical_retrieval.py new file mode 100644 index 000000000..ba763a7da --- /dev/null +++ b/test/unit_test/retrieval/test_hierarchical_retrieval.py @@ -0,0 +1,499 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for Hierarchical Retrieval Architecture + +Tests the three-tier retrieval system without requiring database or +external dependencies. +""" + +import pytest +from rag.retrieval.hierarchical_retrieval import ( + HierarchicalRetrieval, + RetrievalConfig, + RetrievalResult, + KBRouter, + DocumentFilter, + ChunkRefiner +) + + +class TestRetrievalConfig: + """Test RetrievalConfig dataclass""" + + def test_default_config(self): + """Test default configuration values""" + config = RetrievalConfig() + + assert config.enable_kb_routing is True + assert config.kb_routing_method == "auto" + assert config.enable_doc_filtering is True + assert config.enable_parent_child_chunking is False + assert config.vector_weight == 0.7 + assert config.keyword_weight == 0.3 + + def test_custom_config(self): + """Test custom configuration""" + config = RetrievalConfig( + enable_kb_routing=False, + kb_routing_method="rule_based", + metadata_fields=["doc_type", "department"], + chunk_refinement_top_k=20 + ) + + assert config.enable_kb_routing is False + assert config.kb_routing_method == "rule_based" + assert "doc_type" in config.metadata_fields + assert config.chunk_refinement_top_k == 20 + + +class TestRetrievalResult: + """Test RetrievalResult dataclass""" + + def test_empty_result(self): + """Test empty result initialization""" + result = RetrievalResult(query="test query") + + assert result.query == "test query" + assert len(result.selected_kbs) == 0 + assert len(result.filtered_docs) == 0 + assert len(result.retrieved_chunks) == 0 + assert result.total_time_ms == 0.0 + + def test_result_with_data(self): + """Test result with data""" + result = RetrievalResult( + query="test query", + selected_kbs=["kb1", "kb2"], + filtered_docs=[{"id": "doc1"}, {"id": "doc2"}], + retrieved_chunks=[{"id": "chunk1"}], + tier1_candidates=2, + tier2_candidates=2, + tier3_candidates=1, + total_time_ms=150.5 + ) + + assert len(result.selected_kbs) == 2 + assert len(result.filtered_docs) == 2 + assert len(result.retrieved_chunks) == 1 + assert result.tier1_candidates == 2 + assert result.total_time_ms == 150.5 + + +class TestHierarchicalRetrieval: + """Test HierarchicalRetrieval main class""" + + def test_initialization_default_config(self): + """Test initialization with default config""" + retrieval = HierarchicalRetrieval() + + assert retrieval.config is not None + assert retrieval.config.enable_kb_routing is True + + def test_initialization_custom_config(self): + """Test initialization with custom config""" + config = RetrievalConfig(enable_kb_routing=False) + retrieval = HierarchicalRetrieval(config) + + assert retrieval.config.enable_kb_routing is False + + def test_retrieve_basic(self): + """Test basic retrieval flow""" + retrieval = HierarchicalRetrieval() + + result = retrieval.retrieve( + query="test query", + kb_ids=["kb1", "kb2"], + top_k=10 + ) + + assert isinstance(result, RetrievalResult) + assert result.query == "test query" + assert result.total_time_ms >= 0 + + def test_retrieve_with_filters(self): + """Test retrieval with metadata filters""" + retrieval = HierarchicalRetrieval() + + result = retrieval.retrieve( + query="test query", + kb_ids=["kb1"], + top_k=5, + filters={"department": "HR"} + ) + + assert isinstance(result, RetrievalResult) + + def test_retrieve_empty_kb_list(self): + """Test retrieval with empty KB list""" + retrieval = HierarchicalRetrieval() + + result = retrieval.retrieve( + query="test query", + kb_ids=[], + top_k=10 + ) + + # Should return empty result + assert len(result.selected_kbs) == 0 + assert len(result.retrieved_chunks) == 0 + + def test_tier1_kb_routing_disabled(self): + """Test KB routing when disabled""" + config = RetrievalConfig(enable_kb_routing=False) + retrieval = HierarchicalRetrieval(config) + + kb_ids = ["kb1", "kb2", "kb3"] + selected = retrieval._tier1_kb_routing("test query", kb_ids) + + # Should return all KBs when disabled + assert selected == kb_ids + + def test_tier1_kb_routing_all_method(self): + """Test KB routing with 'all' method""" + config = RetrievalConfig(kb_routing_method="all") + retrieval = HierarchicalRetrieval(config) + + kb_ids = ["kb1", "kb2", "kb3"] + selected = retrieval._tier1_kb_routing("test query", kb_ids) + + assert selected == kb_ids + + def test_tier1_kb_routing_rule_based(self): + """Test KB routing with rule-based method""" + config = RetrievalConfig(kb_routing_method="rule_based") + retrieval = HierarchicalRetrieval(config) + + kb_ids = ["kb1", "kb2"] + selected = retrieval._tier1_kb_routing("test query", kb_ids) + + # Currently returns all KBs (placeholder implementation) + assert isinstance(selected, list) + + def test_tier1_kb_routing_llm_based(self): + """Test KB routing with LLM-based method""" + config = RetrievalConfig(kb_routing_method="llm_based") + retrieval = HierarchicalRetrieval(config) + + kb_ids = ["kb1", "kb2"] + selected = retrieval._tier1_kb_routing("test query", kb_ids) + + # Currently returns all KBs (placeholder implementation) + assert isinstance(selected, list) + + def test_tier1_kb_routing_auto(self): + """Test KB routing with auto method""" + config = RetrievalConfig(kb_routing_method="auto") + retrieval = HierarchicalRetrieval(config) + + kb_ids = ["kb1", "kb2"] + selected = retrieval._tier1_kb_routing("test query", kb_ids) + + assert isinstance(selected, list) + + def test_tier2_document_filtering_disabled(self): + """Test document filtering when disabled""" + config = RetrievalConfig(enable_doc_filtering=False) + retrieval = HierarchicalRetrieval(config) + + docs = retrieval._tier2_document_filtering( + "test query", + ["kb1"], + None + ) + + # Should return empty list when disabled + assert isinstance(docs, list) + + def test_tier2_document_filtering_with_filters(self): + """Test document filtering with metadata filters""" + retrieval = HierarchicalRetrieval() + + docs = retrieval._tier2_document_filtering( + "test query", + ["kb1"], + {"department": "HR"} + ) + + assert isinstance(docs, list) + + def test_tier3_chunk_refinement(self): + """Test chunk refinement""" + retrieval = HierarchicalRetrieval() + + chunks = retrieval._tier3_chunk_refinement( + "test query", + [{"id": "doc1"}], + top_k=10 + ) + + assert isinstance(chunks, list) + + def test_timing_metrics(self): + """Test that timing metrics are recorded""" + retrieval = HierarchicalRetrieval() + + result = retrieval.retrieve( + query="test query", + kb_ids=["kb1"], + top_k=5 + ) + + # All timing metrics should be non-negative + assert result.tier1_time_ms >= 0 + assert result.tier2_time_ms >= 0 + assert result.tier3_time_ms >= 0 + assert result.total_time_ms >= 0 + + # Total should be sum of tiers (approximately) + assert result.total_time_ms >= result.tier1_time_ms + + +class TestKBRouter: + """Test KBRouter class""" + + def test_initialization(self): + """Test KBRouter initialization""" + router = KBRouter() + assert router is not None + + def test_route_basic(self): + """Test basic routing""" + router = KBRouter() + + available_kbs = [ + {"id": "kb1", "name": "HR Knowledge Base"}, + {"id": "kb2", "name": "Finance Knowledge Base"} + ] + + selected = router.route( + query="What is the vacation policy?", + available_kbs=available_kbs, + method="auto" + ) + + assert isinstance(selected, list) + assert len(selected) > 0 + + def test_route_empty_kbs(self): + """Test routing with empty KB list""" + router = KBRouter() + + selected = router.route( + query="test query", + available_kbs=[], + method="auto" + ) + + assert selected == [] + + @pytest.mark.parametrize("method", ["auto", "rule_based", "llm_based"]) + def test_route_different_methods(self, method): + """Test routing with different methods""" + router = KBRouter() + + available_kbs = [{"id": "kb1"}, {"id": "kb2"}] + + selected = router.route( + query="test query", + available_kbs=available_kbs, + method=method + ) + + assert isinstance(selected, list) + + +class TestDocumentFilter: + """Test DocumentFilter class""" + + def test_initialization(self): + """Test DocumentFilter initialization""" + filter_obj = DocumentFilter() + assert filter_obj is not None + + def test_filter_basic(self): + """Test basic filtering""" + filter_obj = DocumentFilter() + + documents = [ + {"id": "doc1", "department": "HR"}, + {"id": "doc2", "department": "Finance"}, + {"id": "doc3", "department": "HR"} + ] + + filtered = filter_obj.filter( + query="HR policies", + documents=documents, + metadata_fields=["department"], + filters=None + ) + + assert isinstance(filtered, list) + + def test_filter_with_explicit_filters(self): + """Test filtering with explicit filters""" + filter_obj = DocumentFilter() + + documents = [ + {"id": "doc1", "department": "HR"}, + {"id": "doc2", "department": "Finance"} + ] + + filtered = filter_obj.filter( + query="test", + documents=documents, + metadata_fields=["department"], + filters={"department": "HR"} + ) + + # Currently returns all docs (placeholder) + assert isinstance(filtered, list) + + def test_filter_empty_documents(self): + """Test filtering with empty document list""" + filter_obj = DocumentFilter() + + filtered = filter_obj.filter( + query="test", + documents=[], + metadata_fields=["department"], + filters=None + ) + + assert filtered == [] + + +class TestChunkRefiner: + """Test ChunkRefiner class""" + + def test_initialization(self): + """Test ChunkRefiner initialization""" + refiner = ChunkRefiner() + assert refiner is not None + + def test_refine_basic(self): + """Test basic chunk refinement""" + refiner = ChunkRefiner() + + chunks = refiner.refine( + query="test query", + doc_ids=["doc1", "doc2"], + top_k=10, + use_parent_child=False + ) + + assert isinstance(chunks, list) + + def test_refine_with_parent_child(self): + """Test refinement with parent-child chunking""" + refiner = ChunkRefiner() + + chunks = refiner.refine( + query="test query", + doc_ids=["doc1"], + top_k=5, + use_parent_child=True + ) + + assert isinstance(chunks, list) + + def test_refine_empty_doc_ids(self): + """Test refinement with empty doc IDs""" + refiner = ChunkRefiner() + + chunks = refiner.refine( + query="test query", + doc_ids=[], + top_k=10, + use_parent_child=False + ) + + assert chunks == [] + + +class TestIntegrationScenarios: + """Test end-to-end integration scenarios""" + + def test_full_retrieval_flow(self): + """Test complete retrieval flow""" + config = RetrievalConfig( + enable_kb_routing=True, + enable_doc_filtering=True, + chunk_refinement_top_k=10 + ) + + retrieval = HierarchicalRetrieval(config) + + result = retrieval.retrieve( + query="What is the company vacation policy?", + kb_ids=["hr_kb", "policies_kb"], + top_k=10, + filters={"department": "HR"} + ) + + # Verify result structure + assert isinstance(result, RetrievalResult) + assert result.query == "What is the company vacation policy?" + assert result.total_time_ms > 0 + + # Verify tier progression + assert result.tier1_candidates >= 0 + assert result.tier2_candidates >= 0 + assert result.tier3_candidates >= 0 + + def test_retrieval_with_all_features_enabled(self): + """Test retrieval with all features enabled""" + config = RetrievalConfig( + enable_kb_routing=True, + kb_routing_method="auto", + enable_doc_filtering=True, + enable_metadata_similarity=True, + enable_parent_child_chunking=True, + use_summary_mapping=True, + enable_hybrid_search=True + ) + + retrieval = HierarchicalRetrieval(config) + + result = retrieval.retrieve( + query="test query", + kb_ids=["kb1", "kb2", "kb3"], + top_k=20 + ) + + assert isinstance(result, RetrievalResult) + + def test_retrieval_with_minimal_config(self): + """Test retrieval with minimal configuration""" + config = RetrievalConfig( + enable_kb_routing=False, + enable_doc_filtering=False + ) + + retrieval = HierarchicalRetrieval(config) + + result = retrieval.retrieve( + query="test query", + kb_ids=["kb1"], + top_k=5 + ) + + assert isinstance(result, RetrievalResult) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])