# # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ Hierarchical Retrieval Architecture for Production-Grade RAG Implements a three-tier retrieval system: 1. Tier 1: Knowledge Base Routing - Routes queries to relevant KBs 2. Tier 2: Document Filtering - Filters by metadata 3. Tier 3: Chunk Refinement - Precise vector retrieval This architecture addresses scalability and precision limitations in production environments with large document collections. """ import logging import time from typing import List, Dict, Any, Optional from dataclasses import dataclass, field from rag.nlp import rag_tokenizer from rag.nlp.search import index_name @dataclass class RetrievalConfig: """Configuration for hierarchical retrieval""" # Tier 1: KB Routing enable_kb_routing: bool = True kb_routing_method: str = "auto" # "auto", "rule_based", "llm_based", "all" kb_routing_threshold: float = 0.5 kb_top_k: int = 3 # Max number of KBs to select # Tier 2: Document Filtering enable_doc_filtering: bool = True metadata_fields: List[str] = field(default_factory=list) enable_metadata_similarity: bool = False metadata_similarity_threshold: float = 0.7 doc_top_k: int = 50 # Max documents to pass to tier 3 # Tier 3: Chunk Refinement enable_parent_child_chunking: bool = False use_summary_mapping: bool = False chunk_refinement_top_k: int = 10 # Search parameters similarity_threshold: float = 0.2 vector_weight: float = 0.7 keyword_weight: float = 0.3 rerank: bool = True @dataclass class RetrievalResult: """Result from hierarchical retrieval""" query: str selected_kbs: List[str] = field(default_factory=list) filtered_docs: List[Dict[str, Any]] = field(default_factory=list) retrieved_chunks: List[Dict[str, Any]] = field(default_factory=list) # Metadata about retrieval process tier1_candidates: int = 0 tier2_candidates: int = 0 tier3_candidates: int = 0 total_time_ms: float = 0.0 tier1_time_ms: float = 0.0 tier2_time_ms: float = 0.0 tier3_time_ms: float = 0.0 class HierarchicalRetrieval: """ Three-tier hierarchical retrieval system for production-grade RAG. Integrates with RAGFlow's existing search infrastructure to provide scalable, accurate retrieval for large document collections. """ def __init__(self, config: Optional[RetrievalConfig] = None, datastore_conn=None, embedding_model=None): """ Initialize hierarchical retrieval system. Args: config: Configuration for retrieval behavior datastore_conn: RAGFlow DocStoreConnection instance embedding_model: Embedding model for vector search """ self.config = config or RetrievalConfig() self.datastore = datastore_conn self.embedding_model = embedding_model self.logger = logging.getLogger(__name__) def retrieve( self, query: str, kb_ids: List[str], kb_infos: Optional[List[Dict[str, Any]]] = None, top_k: int = 10, filters: Optional[Dict[str, Any]] = None ) -> RetrievalResult: """ Perform hierarchical retrieval. Args: query: User query kb_ids: List of knowledge base IDs to search kb_infos: Optional KB metadata (name, description) top_k: Number of final chunks to return filters: Optional metadata filters Returns: RetrievalResult with chunks and metadata """ start_time = time.time() result = RetrievalResult(query=query) # Tier 1: Knowledge Base Routing tier1_start = time.time() selected_kbs = self._tier1_kb_routing(query, kb_ids, kb_infos) result.selected_kbs = selected_kbs result.tier1_candidates = len(selected_kbs) result.tier1_time_ms = (time.time() - tier1_start) * 1000 if not selected_kbs: self.logger.warning(f"No knowledge bases selected for query: {query}") result.total_time_ms = (time.time() - start_time) * 1000 return result # Tier 2: Document Filtering tier2_start = time.time() filtered_docs = self._tier2_document_filtering( query, selected_kbs, filters ) result.filtered_docs = filtered_docs result.tier2_candidates = len(filtered_docs) result.tier2_time_ms = (time.time() - tier2_start) * 1000 # Tier 3: Chunk Refinement tier3_start = time.time() chunks = self._tier3_chunk_refinement( query, selected_kbs, filtered_docs, top_k ) result.retrieved_chunks = chunks result.tier3_candidates = len(chunks) result.tier3_time_ms = (time.time() - tier3_start) * 1000 result.total_time_ms = (time.time() - start_time) * 1000 self.logger.info( f"Hierarchical retrieval: {result.tier1_candidates} KBs -> " f"{result.tier2_candidates} docs -> {result.tier3_candidates} chunks " f"in {result.total_time_ms:.2f}ms" ) return result def _tier1_kb_routing( self, query: str, kb_ids: List[str], kb_infos: Optional[List[Dict[str, Any]]] = None ) -> List[str]: """ Tier 1: Route query to relevant knowledge bases. Uses query-KB similarity scoring to select most relevant KBs. """ if not self.config.enable_kb_routing: return kb_ids if not kb_ids: return [] method = self.config.kb_routing_method if method == "all": return kb_ids elif method == "rule_based": return self._rule_based_routing(query, kb_ids, kb_infos) elif method == "llm_based": return self._llm_based_routing(query, kb_ids, kb_infos) else: # "auto" return self._auto_routing(query, kb_ids, kb_infos) def _rule_based_routing( self, query: str, kb_ids: List[str], kb_infos: Optional[List[Dict[str, Any]]] = None ) -> List[str]: """ Rule-based KB routing using keyword matching. Matches query keywords against KB names/descriptions. """ if not kb_infos: # No KB metadata available, return all return kb_ids # Tokenize query query_tokens = set(rag_tokenizer.tokenize(query.lower())) # Score each KB based on keyword overlap kb_scores = [] for kb_info in kb_infos: kb_id = kb_info.get('id') if kb_id not in kb_ids: continue # Get KB name and description kb_text = ' '.join([ kb_info.get('name', ''), kb_info.get('description', '') ]).lower() kb_tokens = set(rag_tokenizer.tokenize(kb_text)) # Calculate overlap score if kb_tokens: overlap = len(query_tokens & kb_tokens) score = overlap / len(query_tokens) if query_tokens else 0 kb_scores.append((kb_id, score)) # Filter by threshold and select top K filtered_kbs = [ kb_id for kb_id, score in kb_scores if score >= self.config.kb_routing_threshold ] if not filtered_kbs: # If no KB passes threshold, return top K by score kb_scores.sort(key=lambda x: x[1], reverse=True) filtered_kbs = [kb_id for kb_id, _ in kb_scores[:self.config.kb_top_k]] return filtered_kbs[:self.config.kb_top_k] if filtered_kbs else kb_ids def _llm_based_routing( self, query: str, kb_ids: List[str], kb_infos: Optional[List[Dict[str, Any]]] = None ) -> List[str]: """ LLM-based KB routing using semantic understanding. Uses LLM to understand query intent and match to KBs. Falls back to rule-based if LLM unavailable. """ # For now, fall back to rule-based routing # Full LLM integration would require: # 1. LLM service integration # 2. Prompt engineering for KB selection # 3. Structured output parsing self.logger.info("LLM-based routing not yet implemented, falling back to rule-based") return self._rule_based_routing(query, kb_ids, kb_infos) def _auto_routing( self, query: str, kb_ids: List[str], kb_infos: Optional[List[Dict[str, Any]]] = None ) -> List[str]: """ Auto routing combines multiple strategies. Uses query analysis to automatically select best routing method. """ # Analyze query characteristics query_len = len(query.split()) # For short queries, use rule-based (faster) if query_len < 5: return self._rule_based_routing(query, kb_ids, kb_infos) # For longer queries, use rule-based with higher threshold # Could be extended to use LLM for complex queries return self._rule_based_routing(query, kb_ids, kb_infos) def _tier2_document_filtering( self, query: str, kb_ids: List[str], filters: Optional[Dict[str, Any]] = None ) -> List[Dict[str, Any]]: """ Tier 2: Filter documents by metadata. Applies metadata filters to reduce document set before vector search. """ if not self.config.enable_doc_filtering or not self.datastore: return [] # Build filter conditions conditions = {'kb_id': kb_ids} # Add user-provided filters if filters: for key in self.config.metadata_fields: if key in filters: conditions[key] = filters[key] try: # Query documents with filters # This is a simplified version - actual implementation would use # RAGFlow's search infrastructure filtered_docs = [] # Get document IDs that match filters # In production, this would query the document service # For now, return empty to proceed to tier 3 return filtered_docs except Exception as e: self.logger.error(f"Document filtering error: {e}") return [] def _tier3_chunk_refinement( self, query: str, kb_ids: List[str], filtered_docs: List[Dict[str, Any]], top_k: int ) -> List[Dict[str, Any]]: """ Tier 3: Perform precise chunk-level retrieval. Uses vector search within selected KBs (and optionally filtered docs). """ if not self.datastore or not self.embedding_model: self.logger.warning("Datastore or embedding model not available") return [] try: from rag.nlp.search import Dealer # Create search dealer dealer = Dealer(self.datastore) # Build search request req = { 'question': query, 'kb_ids': kb_ids, 'size': top_k, 'similarity_threshold': self.config.similarity_threshold, 'vector_similarity_weight': self.config.vector_weight, } # Add document filter if we have filtered docs if filtered_docs: doc_ids = [doc.get('id') or doc.get('doc_id') for doc in filtered_docs] req['doc_ids'] = [did for did in doc_ids if did] # Get index names for KBs idx_names = [index_name(kb_id) for kb_id in kb_ids] # Perform search search_result = dealer.search( req=req, idx_names=idx_names, kb_ids=kb_ids, emb_mdl=self.embedding_model, highlight=True ) # Extract chunks from search results chunks = [] if search_result and hasattr(search_result, 'field'): for doc_id, fields in (search_result.field or {}).items(): chunk = { 'id': doc_id, 'content': fields.get('content_ltks', ''), 'doc_id': fields.get('doc_id', ''), 'kb_id': fields.get('kb_id', ''), 'doc_name': fields.get('docnm_kwd', ''), 'page_num': fields.get('page_num_int'), 'position': fields.get('position_int'), } # Add highlight if available if hasattr(search_result, 'highlight') and search_result.highlight: chunk['highlight'] = search_result.highlight.get(doc_id, {}) chunks.append(chunk) return chunks[:top_k] except Exception as e: self.logger.error(f"Chunk refinement error: {e}") return [] class KBRouter: """ Knowledge Base Router for Tier 1. Standalone router for KB selection based on query intent. """ def __init__(self): self.logger = logging.getLogger(__name__) def route( self, query: str, available_kbs: List[Dict[str, Any]], method: str = "auto", threshold: float = 0.3, top_k: int = 3 ) -> List[str]: """ Route query to relevant knowledge bases. Args: query: User query available_kbs: List of KB metadata dicts method: Routing method threshold: Minimum score threshold top_k: Max KBs to return Returns: List of selected KB IDs """ if not available_kbs: return [] # Tokenize query query_tokens = set(rag_tokenizer.tokenize(query.lower())) # Score each KB kb_scores = [] for kb in available_kbs: kb_text = ' '.join([ kb.get('name', ''), kb.get('description', '') ]).lower() kb_tokens = set(rag_tokenizer.tokenize(kb_text)) if kb_tokens and query_tokens: overlap = len(query_tokens & kb_tokens) score = overlap / len(query_tokens) kb_scores.append((kb['id'], score)) # Filter and sort kb_scores = [(kid, s) for kid, s in kb_scores if s >= threshold] kb_scores.sort(key=lambda x: x[1], reverse=True) selected = [kb_id for kb_id, _ in kb_scores[:top_k]] # If none pass threshold, return top K anyway if not selected and kb_scores: selected = [kb_id for kb_id, _ in kb_scores[:top_k]] return selected if selected else [kb['id'] for kb in available_kbs[:top_k]] class DocumentFilter: """ Document Filter for Tier 2. Filters documents based on metadata before vector search. """ def __init__(self): self.logger = logging.getLogger(__name__) def filter( self, query: str, documents: List[Dict[str, Any]], metadata_fields: List[str], filters: Optional[Dict[str, Any]] = None ) -> List[Dict[str, Any]]: """ Filter documents by metadata. Args: query: User query (for query-aware filtering) documents: List of documents to filter metadata_fields: Metadata fields to consider filters: Explicit filter values Returns: Filtered documents """ if not documents: return [] filtered = documents # Apply explicit filters if filters: for field, value in filters.items(): if field in metadata_fields: filtered = [ doc for doc in filtered if doc.get(field) == value ] return filtered class ChunkRefiner: """ Chunk Refiner for Tier 3. Performs precise chunk-level retrieval using vector search. """ def __init__(self, datastore=None, embedding_model=None): self.logger = logging.getLogger(__name__) self.datastore = datastore self.embedding_model = embedding_model def refine( self, query: str, kb_ids: List[str], doc_ids: Optional[List[str]] = None, top_k: int = 10, similarity_threshold: float = 0.2 ) -> List[Dict[str, Any]]: """ Perform precise chunk retrieval. Args: query: User query kb_ids: KB IDs to search within doc_ids: Optional document IDs to limit search top_k: Number of chunks to return similarity_threshold: Minimum similarity score Returns: List of retrieved chunks """ if not self.datastore or not self.embedding_model: return [] try: from rag.nlp.search import Dealer, index_name dealer = Dealer(self.datastore) req = { 'question': query, 'kb_ids': kb_ids, 'size': top_k, 'similarity_threshold': similarity_threshold, } if doc_ids: req['doc_ids'] = doc_ids idx_names = [index_name(kb_id) for kb_id in kb_ids] result = dealer.search( req=req, idx_names=idx_names, kb_ids=kb_ids, emb_mdl=self.embedding_model ) chunks = [] if result and hasattr(result, 'field'): for doc_id, fields in (result.field or {}).items(): chunks.append({ 'id': doc_id, 'content': fields.get('content_ltks', ''), 'score': fields.get('score', 0.0), }) return chunks[:top_k] except Exception as e: self.logger.error(f"Chunk refinement failed: {e}") return []