diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 52e295fd0..ec0404f43 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -63,6 +63,7 @@ class RetrievalParam(ToolParamBase): self.cross_languages = [] self.toc_enhance = False self.meta_data_filter={} + self.hierarchical_retrieval = False # Enable hierarchical retrieval def check(self): self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold") @@ -174,20 +175,42 @@ class Retrieval(ToolBase, ABC): if kbs: query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE) - kbinfos = settings.retriever.retrieval( - query, - embd_mdl, - [kb.tenant_id for kb in kbs], - filtered_kb_ids, - 1, - self._param.top_n, - self._param.similarity_threshold, - 1 - self._param.keywords_similarity_weight, - doc_ids=doc_ids, - aggs=False, - rerank_mdl=rerank_mdl, - rank_feature=label_question(query, kbs), - ) + + # Use hierarchical retrieval if enabled + if self._param.hierarchical_retrieval: + from rag.nlp.search import HierarchicalConfig + kb_infos = [{"id": kb.id, "name": kb.name, "description": kb.description or ""} for kb in kbs] + kbinfos = settings.retriever.hierarchical_retrieval( + question=query, + embd_mdl=embd_mdl, + tenant_ids=[kb.tenant_id for kb in kbs], + kb_ids=filtered_kb_ids, + kb_infos=kb_infos, + page=1, + page_size=self._param.top_n, + similarity_threshold=self._param.similarity_threshold, + vector_similarity_weight=1 - self._param.keywords_similarity_weight, + doc_ids=doc_ids, + aggs=False, + rerank_mdl=rerank_mdl, + rank_feature=label_question(query, kbs), + hierarchical_config=HierarchicalConfig(enabled=True), + ) + else: + kbinfos = settings.retriever.retrieval( + query, + embd_mdl, + [kb.tenant_id for kb in kbs], + filtered_kb_ids, + 1, + self._param.top_n, + self._param.similarity_threshold, + 1 - self._param.keywords_similarity_weight, + doc_ids=doc_ids, + aggs=False, + rerank_mdl=rerank_mdl, + rank_feature=label_question(query, kbs), + ) 
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Metadata Management API for Hierarchical Retrieval.

Provides REST endpoints for batch CRUD operations on document metadata,
supporting the hierarchical retrieval architecture's Tier 2 document filtering.
"""

from quart import request

from api.apps import current_user, login_required
from api.common.check_team_permission import check_kb_team_permission
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.metadata_service import MetadataService
from api.utils.api_utils import (
    get_json_result,
    server_error_response,
)
from common.constants import RetCode


@manager.route("/batch/get", methods=["POST"])  # noqa: F821
@login_required
async def batch_get_metadata():
    """
    Get metadata for multiple documents.

    Request body:
        {
            "doc_ids": ["doc1", "doc2", ...],
            "fields": ["field1", "field2", ...]   // optional
        }

    Returns:
        {"doc1": {"doc_id": "doc1", "doc_name": "...", "metadata": {...}}, ...}

    NOTE(review): unlike /schema and /statistics, this endpoint performs no
    KB-level permission check on the requested documents — confirm whether
    one is required.
    """
    try:
        req = await request.json
        doc_ids = req.get("doc_ids", [])
        fields = req.get("fields")

        if not doc_ids:
            return get_json_result(
                data={},
                message="No document IDs provided",
                code=RetCode.ARGUMENT_ERROR
            )

        result = MetadataService.batch_get_metadata(doc_ids, fields)
        return get_json_result(data=result)

    except Exception as e:
        return server_error_response(e)


@manager.route("/batch/update", methods=["POST"])  # noqa: F821
@login_required
async def batch_update_metadata():
    """
    Update metadata for multiple documents.

    Request body:
        {
            "updates": [
                {"doc_id": "doc1", "metadata": {"field1": "value1", ...}},
                ...
            ],
            "merge": true   // optional, default true; false replaces all metadata
        }

    Returns:
        {"success_count": 5, "failed_ids": ["doc3"]}
    """
    try:
        req = await request.json
        updates = req.get("updates", [])
        merge = req.get("merge", True)

        if not updates:
            return get_json_result(
                data={"success_count": 0, "failed_ids": []},
                message="No updates provided",
                code=RetCode.ARGUMENT_ERROR
            )

        success_count, failed_ids = MetadataService.batch_update_metadata(updates, merge)

        return get_json_result(data={
            "success_count": success_count,
            "failed_ids": failed_ids
        })

    except Exception as e:
        return server_error_response(e)


@manager.route("/batch/delete-fields", methods=["POST"])  # noqa: F821
@login_required
async def batch_delete_metadata_fields():
    """
    Delete specific metadata fields from multiple documents.

    Request body:
        {
            "doc_ids": ["doc1", "doc2", ...],
            "fields": ["field1", "field2", ...]
        }

    Returns:
        {"success_count": 5, "failed_ids": []}
    """
    try:
        req = await request.json
        doc_ids = req.get("doc_ids", [])
        fields = req.get("fields", [])

        if not doc_ids or not fields:
            return get_json_result(
                data={"success_count": 0, "failed_ids": []},
                message="doc_ids and fields are required",
                code=RetCode.ARGUMENT_ERROR
            )

        success_count, failed_ids = MetadataService.batch_delete_metadata_fields(doc_ids, fields)

        return get_json_result(data={
            "success_count": success_count,
            "failed_ids": failed_ids
        })

    except Exception as e:
        return server_error_response(e)


@manager.route("/batch/set-field", methods=["POST"])  # noqa: F821
@login_required
async def batch_set_metadata_field():
    """
    Set a specific field to the same value for multiple documents.

    Useful for bulk categorization or tagging.

    Request body:
        {
            "doc_ids": ["doc1", "doc2", ...],
            "field_name": "category",
            "field_value": "Technical"
        }

    Returns:
        {"success_count": 5, "failed_ids": []}
    """
    try:
        req = await request.json
        doc_ids = req.get("doc_ids", [])
        field_name = req.get("field_name")
        field_value = req.get("field_value")

        if not doc_ids or not field_name:
            return get_json_result(
                data={"success_count": 0, "failed_ids": []},
                message="doc_ids and field_name are required",
                code=RetCode.ARGUMENT_ERROR
            )

        success_count, failed_ids = MetadataService.batch_set_metadata_field(
            doc_ids, field_name, field_value
        )

        return get_json_result(data={
            "success_count": success_count,
            "failed_ids": failed_ids
        })

    except Exception as e:
        return server_error_response(e)


# Fix: the route needs the <kb_id> URL variable; without it Quart has no value
# to bind to the handler's kb_id parameter.
@manager.route("/schema/<kb_id>", methods=["GET"])  # noqa: F821
@login_required
async def get_metadata_schema(kb_id):
    """
    Get the metadata schema for a knowledge base.

    Returns available metadata fields, their types, and sample values.

    Returns:
        {"field1": {"type": "str", "sample_values": ["a", "b"], "count": 10}, ...}
    """
    try:
        # get_by_id follows the common-service (found, model) tuple convention;
        # testing the tuple itself is always truthy, so unpack and test the flag.
        e, kb = KnowledgebaseService.get_by_id(kb_id)
        if not e:
            return get_json_result(
                data={},
                message="Knowledge base not found",
                code=RetCode.DATA_ERROR
            )

        if not check_kb_team_permission(current_user.id, kb_id):
            return get_json_result(
                data={},
                message="No permission to access this knowledge base",
                code=RetCode.PERMISSION_ERROR
            )

        schema = MetadataService.get_metadata_schema(kb_id)
        return get_json_result(data=schema)

    except Exception as e:
        return server_error_response(e)


# Fix: same missing <kb_id> URL variable as /schema.
@manager.route("/statistics/<kb_id>", methods=["GET"])  # noqa: F821
@login_required
async def get_metadata_statistics(kb_id):
    """
    Get statistics about metadata usage in a knowledge base.

    Returns:
        {
            "total_documents": 100,
            "documents_with_metadata": 80,
            "metadata_coverage": 0.8,
            "field_usage": {"category": 50, "author": 30},
            "unique_fields": 5
        }
    """
    try:
        e, kb = KnowledgebaseService.get_by_id(kb_id)
        if not e:
            return get_json_result(
                data={},
                message="Knowledge base not found",
                code=RetCode.DATA_ERROR
            )

        if not check_kb_team_permission(current_user.id, kb_id):
            return get_json_result(
                data={},
                message="No permission to access this knowledge base",
                code=RetCode.PERMISSION_ERROR
            )

        stats = MetadataService.get_metadata_statistics(kb_id)
        return get_json_result(data=stats)

    except Exception as e:
        return server_error_response(e)


@manager.route("/search", methods=["POST"])  # noqa: F821
@login_required
async def search_by_metadata():
    """
    Search documents by metadata filters.

    Request body:
        {
            "kb_id": "kb123",
            "filters": {
                "category": "Technical",
                "author": {"contains": "John"},
                "year": {"gt": 2020}
            },
            "limit": 100
        }

    Supported operators: equals, contains, starts_with, in, gt, lt

    Returns:
        [{"doc_id": "doc1", "doc_name": "...", "metadata": {...}}, ...]
    """
    try:
        req = await request.json
        kb_id = req.get("kb_id")
        filters = req.get("filters", {})
        limit = req.get("limit", 100)

        if not kb_id:
            return get_json_result(
                data=[],
                message="kb_id is required",
                code=RetCode.ARGUMENT_ERROR
            )

        if not check_kb_team_permission(current_user.id, kb_id):
            return get_json_result(
                data=[],
                message="No permission to access this knowledge base",
                code=RetCode.PERMISSION_ERROR
            )

        results = MetadataService.search_by_metadata(kb_id, filters, limit)
        return get_json_result(data=results)

    except Exception as e:
        return server_error_response(e)


@manager.route("/copy", methods=["POST"])  # noqa: F821
@login_required
async def copy_metadata():
    """
    Copy metadata from one document to multiple target documents.

    Request body:
        {
            "source_doc_id": "doc1",
            "target_doc_ids": ["doc2", "doc3", ...],
            "fields": ["field1", "field2"]   // optional, copies all if omitted
        }

    Returns:
        {"success_count": 5, "failed_ids": []}
    """
    try:
        req = await request.json
        source_doc_id = req.get("source_doc_id")
        target_doc_ids = req.get("target_doc_ids", [])
        fields = req.get("fields")

        if not source_doc_id or not target_doc_ids:
            return get_json_result(
                data={"success_count": 0, "failed_ids": []},
                message="source_doc_id and target_doc_ids are required",
                code=RetCode.ARGUMENT_ERROR
            )

        success_count, failed_ids = MetadataService.copy_metadata(
            source_doc_id, target_doc_ids, fields
        )

        return get_json_result(data={
            "success_count": success_count,
            "failed_ids": failed_ids
        })

    except Exception as e:
        return server_error_response(e)
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Metadata Management Service for Hierarchical Retrieval.

Provides batch CRUD operations for document metadata to support:
- Efficient metadata filtering in Tier 2 of hierarchical retrieval
- Bulk metadata updates across multiple documents
- Metadata schema management per knowledge base
"""

import logging
from typing import List, Dict, Any, Optional, Tuple

from peewee import fn

from api.db.db_models import DB, Document
from api.db.services.document_service import DocumentService


class MetadataService:
    """
    Service for managing document metadata in batch operations.

    Supports the hierarchical retrieval architecture by providing
    efficient metadata management for document filtering.
    """

    @classmethod
    @DB.connection_context()
    def batch_get_metadata(
        cls,
        doc_ids: List[str],
        fields: Optional[List[str]] = None
    ) -> Dict[str, Dict[str, Any]]:
        """
        Get metadata for multiple documents.

        Args:
            doc_ids: List of document IDs
            fields: Optional list of specific metadata fields to retrieve

        Returns:
            Dict mapping doc_id to {"doc_id", "doc_name", "metadata"}.
            Unknown doc_ids are simply absent from the result.
        """
        if not doc_ids:
            return {}

        result = {}
        docs = Document.select(
            Document.id,
            Document.meta_fields,
            Document.name
        ).where(Document.id.in_(doc_ids))

        for doc in docs:
            meta = doc.meta_fields or {}
            if fields:
                # Filter to requested fields only
                meta = {k: v for k, v in meta.items() if k in fields}
            result[doc.id] = {
                "doc_id": doc.id,
                "doc_name": doc.name,
                "metadata": meta
            }

        return result

    @classmethod
    @DB.connection_context()
    def batch_update_metadata(
        cls,
        updates: List[Dict[str, Any]],
        merge: bool = True
    ) -> Tuple[int, List[str]]:
        """
        Update metadata for multiple documents in batch.

        Args:
            updates: List of dicts with 'doc_id' and 'metadata' keys
            merge: If True, merge with existing metadata; if False, replace

        Returns:
            Tuple of (success_count, list of failed doc_ids). Unknown
            doc_ids are reported in failed_ids rather than silently
            counted as successes.
        """
        if not updates:
            return 0, []

        success_count = 0
        failed_ids = []

        for update in updates:
            doc_id = update.get("doc_id")
            new_metadata = update.get("metadata", {})

            if not doc_id:
                # Entry without an ID: nothing to report against.
                continue

            try:
                doc = Document.get_or_none(Document.id == doc_id)
                if doc is None:
                    # Previously this fell through to a zero-row update and
                    # was counted as a success; report it as a failure.
                    failed_ids.append(doc_id)
                    continue

                if merge:
                    # Copy before merging so the model instance's dict is
                    # not mutated as a side effect.
                    merged = dict(doc.meta_fields or {})
                    merged.update(new_metadata)
                    new_metadata = merged

                DocumentService.update_meta_fields(doc_id, new_metadata)
                success_count += 1

            except Exception as e:
                logging.error(f"Failed to update metadata for doc {doc_id}: {e}")
                failed_ids.append(doc_id)

        logging.info(f"Batch metadata update: {success_count} succeeded, {len(failed_ids)} failed")
        return success_count, failed_ids

    @classmethod
    @DB.connection_context()
    def batch_delete_metadata_fields(
        cls,
        doc_ids: List[str],
        fields: List[str]
    ) -> Tuple[int, List[str]]:
        """
        Delete specific metadata fields from multiple documents.

        Args:
            doc_ids: List of document IDs
            fields: List of metadata field names to delete

        Returns:
            Tuple of (success_count, list of failed doc_ids). Documents
            that contain none of the fields are neither counted nor failed.
        """
        if not doc_ids or not fields:
            return 0, []

        success_count = 0
        failed_ids = []

        docs = Document.select(
            Document.id,
            Document.meta_fields
        ).where(Document.id.in_(doc_ids))

        for doc in docs:
            try:
                meta = doc.meta_fields or {}
                modified = False

                for field in fields:
                    if field in meta:
                        del meta[field]
                        modified = True

                if modified:
                    DocumentService.update_meta_fields(doc.id, meta)
                    success_count += 1

            except Exception as e:
                logging.error(f"Failed to delete metadata fields for doc {doc.id}: {e}")
                failed_ids.append(doc.id)

        return success_count, failed_ids

    @classmethod
    @DB.connection_context()
    def batch_set_metadata_field(
        cls,
        doc_ids: List[str],
        field_name: str,
        field_value: Any
    ) -> Tuple[int, List[str]]:
        """
        Set a specific metadata field to the same value for multiple documents.

        Useful for bulk categorization or tagging.

        Args:
            doc_ids: List of document IDs
            field_name: Name of the metadata field
            field_value: Value to set

        Returns:
            Tuple of (success_count, list of failed doc_ids)
        """
        if not doc_ids or not field_name:
            return 0, []

        # Delegate to batch_update_metadata so success/failure accounting
        # stays in one place.
        updates = [
            {"doc_id": doc_id, "metadata": {field_name: field_value}}
            for doc_id in doc_ids
        ]

        return cls.batch_update_metadata(updates, merge=True)

    @classmethod
    @DB.connection_context()
    def get_metadata_schema(cls, kb_id: str) -> Dict[str, Dict[str, Any]]:
        """
        Get the metadata schema for a knowledge base.

        Analyzes all documents in the KB to determine available
        metadata fields and their types/values.

        Args:
            kb_id: Knowledge base ID

        Returns:
            Dict mapping field names to field info (type, sample values, count).
            The reported "type" is the type seen on the first occurrence of
            the field.
        """
        schema = {}

        docs = Document.select(
            Document.meta_fields
        ).where(Document.kb_id == kb_id)

        for doc in docs:
            meta = doc.meta_fields or {}
            for field_name, field_value in meta.items():
                if field_name not in schema:
                    schema[field_name] = {
                        "type": type(field_value).__name__,
                        "sample_values": set(),
                        "count": 0
                    }

                schema[field_name]["count"] += 1

                # Collect sample values (limit to 10, truncated to 100 chars)
                if len(schema[field_name]["sample_values"]) < 10:
                    try:
                        schema[field_name]["sample_values"].add(str(field_value)[:100])
                    except Exception:
                        pass

        # Convert sets to lists for JSON serialization
        for field_name in schema:
            schema[field_name]["sample_values"] = list(schema[field_name]["sample_values"])

        return schema

    @classmethod
    @DB.connection_context()
    def search_by_metadata(
        cls,
        kb_id: str,
        filters: Dict[str, Any],
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Search documents by metadata filters.

        Args:
            kb_id: Knowledge base ID
            filters: Dict of field_name -> value or {operator: value}
            limit: Maximum number of results

        Returns:
            List of matching documents with their metadata.

        A document missing a filtered field never matches an operator
        condition, and non-numeric values simply fail gt/lt rather than
        raising.
        """
        docs = Document.select(
            Document.id,
            Document.name,
            Document.meta_fields
        ).where(Document.kb_id == kb_id)

        results = []
        for doc in docs:
            meta = doc.meta_fields or {}
            matches = True

            for field_name, condition in filters.items():
                doc_value = meta.get(field_name)

                if isinstance(condition, dict):
                    # Operator-based condition; only the first operator in
                    # the dict is applied.
                    op = next(iter(condition))
                    val = condition[op]

                    if doc_value is None:
                        # Missing field: previously str(None) == "None" could
                        # spuriously match; treat as non-match instead.
                        matches = False
                    elif op == "equals":
                        matches = str(doc_value) == str(val)
                    elif op == "contains":
                        matches = str(val).lower() in str(doc_value).lower()
                    elif op == "starts_with":
                        matches = str(doc_value).lower().startswith(str(val).lower())
                    elif op == "in":
                        matches = doc_value in val
                    elif op in ("gt", "lt"):
                        # Non-numeric values fail the comparison instead of
                        # raising and aborting the whole search.
                        try:
                            num = float(doc_value)
                            matches = num > float(val) if op == "gt" else num < float(val)
                        except (TypeError, ValueError):
                            matches = False
                    # Unknown operators are ignored (condition passes),
                    # preserving the original lenient behavior.
                else:
                    # Simple equality
                    matches = str(doc_value) == str(condition)

                if not matches:
                    break

            if matches:
                results.append({
                    "doc_id": doc.id,
                    "doc_name": doc.name,
                    "metadata": meta
                })

                if len(results) >= limit:
                    break

        return results

    @classmethod
    @DB.connection_context()
    def get_metadata_statistics(cls, kb_id: str) -> Dict[str, Any]:
        """
        Get statistics about metadata usage in a knowledge base.

        Args:
            kb_id: Knowledge base ID

        Returns:
            Dict with statistics about metadata fields
        """
        total_docs = Document.select(fn.COUNT(Document.id)).where(
            Document.kb_id == kb_id
        ).scalar()

        docs_with_metadata = 0
        field_usage = {}

        docs = Document.select(Document.meta_fields).where(Document.kb_id == kb_id)

        for doc in docs:
            meta = doc.meta_fields or {}
            if meta:
                docs_with_metadata += 1
                for field_name in meta.keys():
                    field_usage[field_name] = field_usage.get(field_name, 0) + 1

        return {
            "total_documents": total_docs,
            "documents_with_metadata": docs_with_metadata,
            "metadata_coverage": docs_with_metadata / total_docs if total_docs > 0 else 0,
            "field_usage": field_usage,
            "unique_fields": len(field_usage)
        }

    @classmethod
    @DB.connection_context()
    def copy_metadata(
        cls,
        source_doc_id: str,
        target_doc_ids: List[str],
        fields: Optional[List[str]] = None
    ) -> Tuple[int, List[str]]:
        """
        Copy metadata from one document to multiple target documents.

        Args:
            source_doc_id: Source document ID
            target_doc_ids: List of target document IDs
            fields: Optional list of specific fields to copy (all if None)

        Returns:
            Tuple of (success_count, list of failed doc_ids). If the source
            document does not exist, every target is reported failed; if the
            source has no (matching) metadata, (0, []) is returned.
        """
        source_doc = Document.get_or_none(Document.id == source_doc_id)
        if not source_doc:
            return 0, target_doc_ids

        source_meta = source_doc.meta_fields or {}

        if fields:
            source_meta = {k: v for k, v in source_meta.items() if k in fields}

        if not source_meta:
            return 0, []

        updates = [
            {"doc_id": doc_id, "metadata": source_meta.copy()}
            for doc_id in target_doc_ids
        ]

        return cls.batch_update_metadata(updates, merge=True)
@dataclass
class KBRetrievalParams:
    """Per-KB retrieval parameters for independent configuration."""
    kb_id: str
    vector_similarity_weight: float = 0.7  # Weight for vector vs keyword
    similarity_threshold: float = 0.2
    top_k: int = 1024
    rerank_enabled: bool = True


@dataclass
class HierarchicalConfig:
    """Configuration for hierarchical retrieval.

    Hierarchical retrieval uses a three-tier approach:
    1. KB Routing: Select relevant knowledge bases based on query
    2. Document Filtering: Filter documents by metadata before vector search
    3. Chunk Refinement: Precise vector search within filtered scope
    """

    # Enable hierarchical retrieval
    enabled: bool = False

    # Tier 1: KB Routing
    enable_kb_routing: bool = True
    kb_routing_method: str = "auto"  # "auto", "rule_based", "llm_based", "all"
    kb_routing_threshold: float = 0.3  # Min keyword overlap score
    kb_top_k: int = 3  # Max KBs to select
    kb_params: Dict[str, KBRetrievalParams] = field(default_factory=dict)  # Per-KB params

    # Tier 2: Document Filtering
    enable_doc_filtering: bool = True
    doc_top_k: int = 100  # Max documents to pass to tier 3
    metadata_fields: List[str] = field(default_factory=list)  # Fields for filtering
    enable_metadata_similarity: bool = False  # Fuzzy matching for text metadata
    metadata_similarity_threshold: float = 0.7
    use_llm_metadata_filter: bool = False  # LLM-generated filter conditions

    # Tier 3: Chunk Refinement
    chunk_top_k: int = 10
    enable_parent_child: bool = False  # Parent-child chunking with summary mapping
    use_summary_mapping: bool = False  # Match via summary vectors first

    # Customizable prompts
    keyword_extraction_prompt: Optional[str] = None
    question_generation_prompt: Optional[str] = None

    # LLM-based enhancements
    use_llm_question_generation: bool = False  # Generate questions from chunks


@dataclass
class HierarchicalResult:
    """Result metadata from hierarchical retrieval."""

    selected_kb_ids: List[str] = field(default_factory=list)
    filtered_doc_ids: List[str] = field(default_factory=list)
    tier1_time_ms: float = 0.0
    tier2_time_ms: float = 0.0
    tier3_time_ms: float = 0.0
    total_time_ms: float = 0.0


# NOTE(review): the methods below belong inside `class Dealer` (the diff
# inserts them there); they are shown at method indentation.

    def hierarchical_retrieval(
        self,
        question: str,
        embd_mdl,
        tenant_ids,
        kb_ids: List[str],
        kb_infos: Optional[List[Dict[str, Any]]] = None,
        page: int = 1,
        page_size: int = 10,
        similarity_threshold: float = 0.2,
        vector_similarity_weight: float = 0.3,
        top: int = 1024,
        doc_ids: Optional[List[str]] = None,
        aggs: bool = True,
        rerank_mdl=None,
        highlight: bool = False,
        rank_feature: Optional[dict] = None,
        hierarchical_config: Optional[HierarchicalConfig] = None,
        chat_mdl=None,
        doc_metadata: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """
        Perform hierarchical retrieval with three tiers:
        1. KB Routing: Select relevant knowledge bases
        2. Document Filtering: Filter by document metadata
        3. Chunk Refinement: Precise vector search

        Args:
            question: User query
            embd_mdl: Embedding model
            tenant_ids: Tenant IDs
            kb_ids: Knowledge base IDs to search
            kb_infos: Optional KB metadata (name, description) for routing
            page: Page number
            page_size: Results per page
            similarity_threshold: Minimum similarity score
            vector_similarity_weight: Weight for vector vs keyword similarity
            top: Maximum results to consider
            doc_ids: Optional document IDs to filter
            aggs: Whether to include aggregations
            rerank_mdl: Optional reranking model
            highlight: Whether to include highlights
            rank_feature: Ranking features
            hierarchical_config: Hierarchical retrieval configuration
            chat_mdl: Optional chat model for LLM-based features
            doc_metadata: Optional document metadata for filtering

        Returns:
            Dict with chunks, doc_aggs, and hierarchical_metadata (always a
            plain JSON-serializable dict, on every return path)
        """
        if rank_feature is None:
            rank_feature = {PAGERANK_FLD: 10}

        config = hierarchical_config or HierarchicalConfig()
        start_time = time.time()

        h_result = HierarchicalResult()

        def _meta_dict() -> Dict[str, Any]:
            # Single serialization point so every return path emits the same
            # plain-dict shape (the dataclass itself is not JSON-serializable).
            return {
                "selected_kb_ids": h_result.selected_kb_ids,
                "filtered_doc_ids": h_result.filtered_doc_ids,
                "tier1_time_ms": h_result.tier1_time_ms,
                "tier2_time_ms": h_result.tier2_time_ms,
                "tier3_time_ms": h_result.tier3_time_ms,
                "total_time_ms": h_result.total_time_ms,
                "kb_params_used": {kb_id: True for kb_id in h_result.selected_kb_ids if kb_id in config.kb_params},
            }

        # Tier 1: KB Routing
        tier1_start = time.time()
        selected_kb_ids = self._tier1_kb_routing(
            question, kb_ids, kb_infos, config, chat_mdl
        )
        h_result.selected_kb_ids = selected_kb_ids
        h_result.tier1_time_ms = (time.time() - tier1_start) * 1000

        if not selected_kb_ids:
            logging.warning(f"Hierarchical retrieval: No KBs selected for query: {question[:50]}...")
            # Fix: previously returned the raw HierarchicalResult dataclass
            # (inconsistent with the success path) and never set total_time_ms.
            h_result.total_time_ms = (time.time() - start_time) * 1000
            return {
                "total": 0,
                "chunks": [],
                "doc_aggs": [],
                "hierarchical_metadata": _meta_dict(),
            }

        logging.info(f"Tier 1: Selected {len(selected_kb_ids)}/{len(kb_ids)} KBs")

        # Get per-KB parameters for selected KBs (currently logged only)
        kb_specific_params = {
            kb_id: config.kb_params.get(kb_id)
            for kb_id in selected_kb_ids
            if kb_id in config.kb_params
        }
        if kb_specific_params:
            logging.info(f"Using per-KB params for: {list(kb_specific_params.keys())}")

        # Tier 2: Document Filtering
        tier2_start = time.time()
        filtered_doc_ids = self._tier2_document_filtering(
            question, tenant_ids, selected_kb_ids, doc_ids, config, embd_mdl,
            chat_mdl, doc_metadata
        )
        h_result.filtered_doc_ids = filtered_doc_ids
        h_result.tier2_time_ms = (time.time() - tier2_start) * 1000

        if filtered_doc_ids:
            logging.info(f"Tier 2: Filtered to {len(filtered_doc_ids)} documents")

        # Tier 3: Chunk Refinement
        tier3_start = time.time()

        # Use filtered doc_ids if available, otherwise use original
        effective_doc_ids = filtered_doc_ids if filtered_doc_ids else doc_ids

        # Parent-child chunking with summary mapping
        if config.enable_parent_child and config.use_summary_mapping:
            ranks = self._tier3_summary_mapping_retrieval(
                question=question,
                embd_mdl=embd_mdl,
                tenant_ids=tenant_ids,
                kb_ids=selected_kb_ids,
                doc_ids=effective_doc_ids,
                page=page,
                page_size=page_size,
                similarity_threshold=similarity_threshold,
                vector_similarity_weight=vector_similarity_weight,
                top=top,
                aggs=aggs,
                rerank_mdl=rerank_mdl,
                highlight=highlight,
                rank_feature=rank_feature,
                config=config,
            )
        else:
            # Standard retrieval scoped by Tier 1 KBs and Tier 2 documents
            ranks = self.retrieval(
                question=question,
                embd_mdl=embd_mdl,
                tenant_ids=tenant_ids,
                kb_ids=selected_kb_ids,
                page=page,
                page_size=page_size,
                similarity_threshold=similarity_threshold,
                vector_similarity_weight=vector_similarity_weight,
                top=top,
                doc_ids=effective_doc_ids,
                aggs=aggs,
                rerank_mdl=rerank_mdl,
                highlight=highlight,
                rank_feature=rank_feature,
            )

        # Apply customizable prompts for keyword extraction if configured
        if config.keyword_extraction_prompt and ranks.get("chunks"):
            ranks["chunks"] = self._apply_custom_keyword_extraction(
                ranks["chunks"], config.keyword_extraction_prompt, embd_mdl
            )

        # Apply LLM-based question generation if configured
        if config.use_llm_question_generation and ranks.get("chunks") and chat_mdl:
            ranks["chunks"] = self._apply_llm_question_generation(
                ranks["chunks"], config.question_generation_prompt, chat_mdl
            )

        h_result.tier3_time_ms = (time.time() - tier3_start) * 1000
        h_result.total_time_ms = (time.time() - start_time) * 1000

        # Add hierarchical metadata to result
        ranks["hierarchical_metadata"] = _meta_dict()

        logging.info(
            f"Hierarchical retrieval: {len(selected_kb_ids)} KBs -> "
            f"{len(filtered_doc_ids) if filtered_doc_ids else 'all'} docs -> "
            f"{len(ranks.get('chunks', []))} chunks in {h_result.total_time_ms:.1f}ms"
        )

        return ranks

    def _tier1_kb_routing(
        self,
        question: str,
        kb_ids: List[str],
        kb_infos: Optional[List[Dict[str, Any]]],
        config: HierarchicalConfig,
        chat_mdl=None
    ) -> List[str]:
        """
        Tier 1: Route query to relevant knowledge bases.

        Supports multiple routing methods:
        - "all": Return all KBs (no filtering)
        - "rule_based": Keyword overlap matching
        - "llm_based": LLM-powered intent analysis
        - "auto": Combines rule-based with LLM fallback

        Routing is skipped (all kb_ids returned) when disabled, when no KB
        metadata is available, or when there are already <= kb_top_k KBs.
        """
        if not config.enable_kb_routing or not kb_infos:
            return kb_ids

        if len(kb_ids) <= config.kb_top_k:
            return kb_ids

        method = config.kb_routing_method

        if method == "all":
            return kb_ids
        elif method == "llm_based":
            return self._llm_kb_routing(question, kb_ids, kb_infos, config, chat_mdl)
        elif method == "rule_based":
            return self._rule_based_kb_routing(question, kb_ids, kb_infos, config)
        else:  # "auto" - try rule-based first, fall back to LLM if needed
            rule_result = self._rule_based_kb_routing(question, kb_ids, kb_infos, config)
            # If rule-based actually narrowed the set, use it
            if rule_result and len(rule_result) < len(kb_ids):
                return rule_result
            # Otherwise try LLM if available
            if chat_mdl:
                return self._llm_kb_routing(question, kb_ids, kb_infos, config, chat_mdl)
            return rule_result

    def _rule_based_kb_routing(
        self,
        question: str,
        kb_ids: List[str],
        kb_infos: List[Dict[str, Any]],
        config: HierarchicalConfig
    ) -> List[str]:
        """Rule-based KB routing using keyword overlap between the query and
        each KB's name + description; falls back to all KBs when nothing can
        be scored."""
        # Tokenize query
        query_tokens = set(rag_tokenizer.tokenize(question.lower()).split())
        if not query_tokens:
            return kb_ids

        # Score each KB by fraction of query tokens present in its text
        kb_scores = []
        for kb_info in kb_infos:
            kb_id = kb_info.get('id')
            if kb_id not in kb_ids:
                continue

            # Combine name and description
            kb_text = ' '.join([
                kb_info.get('name', ''),
                kb_info.get('description', '')
            ]).lower()

            kb_tokens = set(rag_tokenizer.tokenize(kb_text).split())

            if kb_tokens:
                overlap = len(query_tokens & kb_tokens)
                score = overlap / len(query_tokens)
                kb_scores.append((kb_id, score))

        if not kb_scores:
            return kb_ids

        # Sort by score descending
        kb_scores.sort(key=lambda x: x[1], reverse=True)

        # Filter by threshold
        selected = [
            kb_id for kb_id, score in kb_scores
            if score >= config.kb_routing_threshold
        ]

        # If none pass threshold, take top K
        if not selected:
            selected = [kb_id for kb_id, _ in kb_scores[:config.kb_top_k]]

        return selected[:config.kb_top_k] if selected else kb_ids

    def _llm_kb_routing(
        self,
        question: str,
        kb_ids: List[str],
        kb_infos: List[Dict[str, Any]],
        config: HierarchicalConfig,
        chat_mdl=None
    ) -> List[str]:
        """LLM-based KB routing using semantic understanding; falls back to
        rule-based routing when no chat model is available, the LLM response
        is unusable, or an error occurs."""
        if not chat_mdl:
            logging.warning("LLM routing requested but no chat model provided, falling back to rule-based")
            return self._rule_based_kb_routing(question, kb_ids, kb_infos, config)

        try:
            # Build numbered KB descriptions for the LLM
            kb_descriptions = []
            kb_id_map = {}
            for i, kb_info in enumerate(kb_infos):
                kb_id = kb_info.get('id')
                if kb_id not in kb_ids:
                    continue
                name = kb_info.get('name', f'KB_{i}')
                desc = kb_info.get('description', 'No description')
                kb_descriptions.append(f"{i+1}. {name}: {desc}")
                kb_id_map[i+1] = kb_id

            if not kb_descriptions:
                return kb_ids

            # Create prompt for LLM
            prompt = f"""Given the following user query and available knowledge bases, select the most relevant knowledge bases that would contain information to answer the query.

User Query: {question}

Available Knowledge Bases:
{chr(10).join(kb_descriptions)}

Instructions:
- Return ONLY the numbers of the most relevant knowledge bases (up to {config.kb_top_k})
- Format: comma-separated numbers, e.g., "1, 3, 5"
- If unsure, include more rather than fewer

Selected knowledge bases:"""

            # Call LLM with low temperature for deterministic selection
            response = chat_mdl.chat(prompt, [], {"temperature": 0.1})
            if not response:
                return self._rule_based_kb_routing(question, kb_ids, kb_infos, config)

            # Parse response - extract numbers (re is imported at module level;
            # the redundant function-local import was removed)
            numbers = re.findall(r'\d+', response)
            selected = []
            for num_str in numbers:
                num = int(num_str)
                if num in kb_id_map:
                    selected.append(kb_id_map[num])

            if selected:
                logging.info(f"LLM routing selected {len(selected)} KBs: {selected[:3]}...")
                return selected[:config.kb_top_k]

            # Fallback to rule-based if LLM response was unparseable
            return self._rule_based_kb_routing(question, kb_ids, kb_infos, config)

        except Exception as e:
            logging.error(f"LLM KB routing failed: {e}, falling back to rule-based")
            return self._rule_based_kb_routing(question, kb_ids, kb_infos, config)
+ + Supports multiple filtering strategies: + - Vector-based: Lightweight search to find relevant documents + - Metadata filtering: Filter by specified metadata fields + - Similarity matching: Fuzzy matching on document names/summaries + - LLM-based: Use LLM to generate filter conditions + """ + if not config.enable_doc_filtering: + return doc_ids or [] + + if doc_ids: + # Already have doc_ids filter, just limit count + return doc_ids[:config.doc_top_k] + + try: + if isinstance(tenant_ids, str): + tenant_ids = tenant_ids.split(",") + + filtered_docs = set() + + # Strategy 1: LLM-based metadata filtering + if config.use_llm_metadata_filter and chat_mdl and doc_metadata: + llm_filtered = self._llm_metadata_filter( + question, doc_metadata, config, chat_mdl + ) + if llm_filtered: + filtered_docs.update(llm_filtered) + logging.info(f"LLM metadata filter selected {len(llm_filtered)} docs") + + # Strategy 2: Metadata similarity matching + if config.enable_metadata_similarity and doc_metadata: + similarity_filtered = self._metadata_similarity_filter( + question, doc_metadata, config, embd_mdl + ) + if similarity_filtered: + if filtered_docs: + filtered_docs.intersection_update(similarity_filtered) + else: + filtered_docs.update(similarity_filtered) + logging.info(f"Metadata similarity filter: {len(similarity_filtered)} docs") + + # Strategy 3: Vector-based document search (default) + if not filtered_docs: + req = { + "kb_ids": kb_ids, + "question": question, + "size": config.doc_top_k, + "topk": config.doc_top_k * 2, + "similarity": 0.1, # Lower threshold for document filtering + "available_int": 1, + } + + idx_names = [index_name(tid) for tid in tenant_ids] + + # Search with embedding if available + sres = self.search(req, idx_names, kb_ids, embd_mdl, highlight=False) + + if sres and sres.field: + for chunk_id in sres.ids: + doc_id = sres.field[chunk_id].get("doc_id") + if doc_id: + filtered_docs.add(doc_id) + + return list(filtered_docs)[:config.doc_top_k] + + except 
Exception as e: + logging.error(f"Tier 2 document filtering error: {e}") + return [] + + def _llm_metadata_filter( + self, + question: str, + doc_metadata: List[Dict[str, Any]], + config: HierarchicalConfig, + chat_mdl + ) -> List[str]: + """Use LLM to generate and apply metadata filter conditions.""" + try: + # Get available metadata fields + if not doc_metadata: + return [] + + # Sample metadata to show LLM what's available + sample_fields = set() + sample_values = {} + for doc in doc_metadata[:10]: + for field in config.metadata_fields or doc.keys(): + if field in doc and field not in ['id', 'doc_id']: + sample_fields.add(field) + if field not in sample_values: + sample_values[field] = [] + if len(sample_values[field]) < 3: + sample_values[field].append(str(doc[field])[:50]) + + if not sample_fields: + return [] + + # Build prompt + field_examples = "\n".join([ + f"- {field}: examples = {sample_values.get(field, [])}" + for field in sample_fields + ]) + + prompt = f"""Given a user query and available document metadata fields, generate filter conditions to select relevant documents. 
+ +User Query: {question} + +Available Metadata Fields: +{field_examples} + +Instructions: +- Return filter conditions as JSON: {{"field": "value"}} or {{"field": {{"operator": "value"}}}} +- Supported operators: "contains", "equals", "starts_with" +- Only filter on fields that are clearly relevant to the query +- Return empty {{}} if no clear filter applies + +Filter conditions (JSON only):""" + + response = chat_mdl.chat(prompt, [], {"temperature": 0.1}) + if not response: + return [] + + # Parse JSON response + try: + # Extract JSON from response + json_match = re.search(r'\{[^}]*\}', response) + if json_match: + filters = json.loads(json_match.group()) + else: + return [] + except json.JSONDecodeError: + return [] + + if not filters: + return [] + + # Apply filters to documents + filtered_ids = [] + for doc in doc_metadata: + doc_id = doc.get('id') or doc.get('doc_id') + if not doc_id: + continue + + matches = True + for field, condition in filters.items(): + doc_value = str(doc.get(field, '')).lower() + + if isinstance(condition, dict): + op = list(condition.keys())[0] + val = str(list(condition.values())[0]).lower() + if op == "contains": + matches = val in doc_value + elif op == "equals": + matches = val == doc_value + elif op == "starts_with": + matches = doc_value.startswith(val) + else: + # Simple equality + matches = str(condition).lower() in doc_value + + if not matches: + break + + if matches: + filtered_ids.append(doc_id) + + return filtered_ids + + except Exception as e: + logging.error(f"LLM metadata filter error: {e}") + return [] + + def _metadata_similarity_filter( + self, + question: str, + doc_metadata: List[Dict[str, Any]], + config: HierarchicalConfig, + embd_mdl=None + ) -> List[str]: + """Filter documents by metadata text similarity.""" + try: + if not embd_mdl or not doc_metadata: + return [] + + # Get query embedding + query_vec, _ = embd_mdl.encode_queries(question) + if not query_vec or len(query_vec) == 0: + return [] + + query_vec = 
np.array(query_vec) + + # Score each document by metadata similarity + doc_scores = [] + for doc in doc_metadata: + doc_id = doc.get('id') or doc.get('doc_id') + if not doc_id: + continue + + # Combine relevant text fields + text_parts = [] + for field in ['name', 'title', 'summary', 'description', 'docnm_kwd']: + if field in doc and doc[field]: + text_parts.append(str(doc[field])) + + if not text_parts: + continue + + doc_text = ' '.join(text_parts) + + # Get document text embedding + doc_vec, _ = embd_mdl.encode([doc_text]) + if not doc_vec or len(doc_vec) == 0: + continue + + doc_vec = np.array(doc_vec[0]) + + # Cosine similarity + similarity = np.dot(query_vec, doc_vec) / ( + np.linalg.norm(query_vec) * np.linalg.norm(doc_vec) + 1e-8 + ) + + if similarity >= config.metadata_similarity_threshold: + doc_scores.append((doc_id, float(similarity))) + + # Sort by similarity and return top docs + doc_scores.sort(key=lambda x: x[1], reverse=True) + return [doc_id for doc_id, _ in doc_scores[:config.doc_top_k]] + + except Exception as e: + logging.error(f"Metadata similarity filter error: {e}") + return [] + + def _tier3_summary_mapping_retrieval( + self, + question: str, + embd_mdl, + tenant_ids, + kb_ids: List[str], + doc_ids: Optional[List[str]], + page: int, + page_size: int, + similarity_threshold: float, + vector_similarity_weight: float, + top: int, + aggs: bool, + rerank_mdl, + highlight: bool, + rank_feature: Optional[dict], + config: HierarchicalConfig + ) -> Dict[str, Any]: + """ + Tier 3 with parent-child chunking and summary mapping. + + First matches macro-themes via summary/parent vectors, + then maps to original child chunks for detailed information. 
+ """ + try: + if isinstance(tenant_ids, str): + tenant_ids = tenant_ids.split(",") + + idx_names = [index_name(tid) for tid in tenant_ids] + + # Step 1: Search for parent/summary chunks first + # Look for chunks with mom_id (parent chunks) or summary content + parent_req = { + "kb_ids": kb_ids, + "doc_ids": doc_ids, + "question": question, + "size": config.chunk_top_k * 3, # Get more parents to find children + "topk": top, + "similarity": similarity_threshold * 0.8, # Slightly lower for parents + "available_int": 1, + } + + parent_sres = self.search( + parent_req, idx_names, kb_ids, embd_mdl, + highlight=False, rank_feature=rank_feature + ) + + if not parent_sres or not parent_sres.ids: + # Fallback to standard retrieval + return self.retrieval( + question=question, + embd_mdl=embd_mdl, + tenant_ids=tenant_ids, + kb_ids=kb_ids, + page=page, + page_size=page_size, + similarity_threshold=similarity_threshold, + vector_similarity_weight=vector_similarity_weight, + top=top, + doc_ids=doc_ids, + aggs=aggs, + rerank_mdl=rerank_mdl, + highlight=highlight, + rank_feature=rank_feature, + ) + + # Step 2: Find child chunks for matched parents + parent_ids = set() + child_doc_ids = set() + + for chunk_id in parent_sres.ids: + chunk = parent_sres.field.get(chunk_id, {}) + mom_id = chunk.get("mom_id") + doc_id = chunk.get("doc_id") + + if mom_id: + # This is a child chunk, get its parent + parent_ids.add(mom_id) + else: + # This might be a parent, use its ID + parent_ids.add(chunk_id) + + if doc_id: + child_doc_ids.add(doc_id) + + # Step 3: Retrieve child chunks for the matched parents + # Search within the documents that contain matched parents + child_req = { + "kb_ids": kb_ids, + "doc_ids": list(child_doc_ids) if child_doc_ids else doc_ids, + "question": question, + "size": page_size * 2, + "topk": top, + "similarity": similarity_threshold, + "available_int": 1, + } + + child_sres = self.search( + child_req, idx_names, kb_ids, embd_mdl, + highlight=highlight, 
rank_feature=rank_feature + ) + + if not child_sres or not child_sres.ids: + # Use parent results if no children found + child_sres = parent_sres + + # Step 4: Rerank and format results + if rerank_mdl and child_sres.total > 0: + sim, tsim, vsim = self.rerank_by_model( + rerank_mdl, child_sres, question, + 1 - vector_similarity_weight, vector_similarity_weight, + rank_feature=rank_feature, + ) + else: + sim, tsim, vsim = self.rerank( + child_sres, question, + 1 - vector_similarity_weight, vector_similarity_weight, + rank_feature=rank_feature, + ) + + # Format results + ranks = {"total": 0, "chunks": [], "doc_aggs": {}} + + sim_np = np.array(sim, dtype=np.float64) + if sim_np.size == 0: + ranks["doc_aggs"] = [] + return ranks + + sorted_idx = np.argsort(sim_np * -1) + valid_idx = [int(i) for i in sorted_idx if sim_np[i] >= similarity_threshold] + ranks["total"] = len(valid_idx) + + if not valid_idx: + ranks["doc_aggs"] = [] + return ranks + + # Get page of results + begin = (page - 1) * page_size + end = begin + page_size + page_idx = valid_idx[begin:end] + + dim = len(child_sres.query_vector) if child_sres.query_vector else 0 + vector_column = f"q_{dim}_vec" if dim else "" + zero_vector = [0.0] * dim if dim else [] + + for i in page_idx: + chunk_id = child_sres.ids[i] + chunk = child_sres.field[chunk_id] + + d = { + "chunk_id": chunk_id, + "content_ltks": chunk.get("content_ltks", ""), + "content_with_weight": chunk.get("content_with_weight", ""), + "doc_id": chunk.get("doc_id", ""), + "docnm_kwd": chunk.get("docnm_kwd", ""), + "kb_id": chunk.get("kb_id", ""), + "important_kwd": chunk.get("important_kwd", []), + "image_id": chunk.get("img_id", ""), + "similarity": float(sim_np[i]), + "vector_similarity": float(vsim[i]) if i < len(vsim) else 0.0, + "term_similarity": float(tsim[i]) if i < len(tsim) else 0.0, + "vector": chunk.get(vector_column, zero_vector), + "positions": chunk.get("position_int", []), + "doc_type_kwd": chunk.get("doc_type_kwd", ""), + "mom_id": 
def _apply_custom_keyword_extraction(
    self,
    chunks: List[Dict[str, Any]],
    prompt_template: str,
    embd_mdl=None
) -> List[Dict[str, Any]]:
    """
    Apply customizable prompts for chunk keyword extraction.

    Mutates each chunk in place and returns the same list: when the
    template mentions "important", heuristically promotes capitalized /
    snake- or kebab-case tokens into ``important_kwd``; when it mentions
    "question", attaches a ``question_hint``.  ``embd_mdl`` is accepted
    for interface compatibility but unused here.
    """
    try:
        template = prompt_template.lower()
        for chunk in chunks:
            content = chunk.get("content_with_weight", "") or chunk.get("content_ltks", "")
            if not content:
                continue

            tokens = rag_tokenizer.tokenize(content).split()

            if "important" in template:
                # Heuristic: capitalized words and snake/kebab-case tokens
                # tend to be technical terms worth boosting.
                candidates = [
                    t for t in tokens
                    if len(t) > 3 and (t[0].isupper() or '_' in t or '-' in t)
                ]
                if candidates:
                    # `or []` guards against an explicit None value, which
                    # would crash the list concatenation below.
                    existing = chunk.get("important_kwd") or []
                    if isinstance(existing, str):
                        existing = [existing]
                    # Dedupe with a stable, reproducible order; plain
                    # set() ordering varies across interpreter runs.
                    chunk["important_kwd"] = list(dict.fromkeys(existing + candidates[:10]))

            if "question" in template:
                # Generate potential questions from content.
                # This is a simplified version - full implementation would use LLM
                if "?" not in content and len(content) > 50:
                    chunk["question_hint"] = f"What does this section explain about {tokens[0] if tokens else 'the topic'}?"

        return chunks

    except Exception as e:
        logging.error(f"Custom keyword extraction error: {e}")
        return chunks
def generate_document_metadata(
    self,
    doc_id: str,
    content: str,
    chat_mdl=None,
    embd_mdl=None
) -> Dict[str, Any]:
    """
    Generate enhanced metadata for a document.

    This is a Data Pipeline enhancement hook that generates:
    - Document summary
    - Key topics/themes
    - Suggested questions
    - Category classification

    When ``embd_mdl`` is supplied and a summary was produced, a
    ``summary_vector`` embedding is added as well.  Fields stay at their
    empty defaults when no chat model or content is available, or when
    generation fails.
    """
    metadata = {
        "doc_id": doc_id,
        "generated": True,
        "summary": "",
        "topics": [],
        "suggested_questions": [],
        "category": "",
    }

    if not chat_mdl or not content:
        return metadata

    try:
        # Keep the LLM input bounded.
        content_truncated = content[:3000]

        # NOTE(review): the placeholder markup in this prompt was garbled
        # in extraction; reconstructed from the parser below -- confirm.
        prompt = f"""Analyze the following document content and provide:
1. A brief summary (2-3 sentences)
2. Key topics/themes (comma-separated list)
3. 3 questions this document could answer
4. A category classification

Content:
{content_truncated}

Respond in this exact format:
SUMMARY: <summary>
TOPICS: <topic1>, <topic2>, <topic3>
QUESTIONS:
- <question1>
- <question2>
- <question3>
CATEGORY: <category>"""

        response = chat_mdl.chat(prompt, [], {"temperature": 0.2, "max_tokens": 500})

        if response:
            for line in response.strip().split('\n'):
                line = line.strip()
                # partition(":") replaces the brittle hard-coded slice
                # offsets (line[8:], line[7:], line[9:]) of each prefix.
                if line.startswith("SUMMARY:"):
                    metadata["summary"] = line.partition(":")[2].strip()
                elif line.startswith("TOPICS:"):
                    topics = line.partition(":")[2].split(',')
                    metadata["topics"] = [t.strip() for t in topics if t.strip()]
                elif line.startswith("- ") and "?" in line:
                    metadata["suggested_questions"].append(line[2:].strip())
                elif line.startswith("CATEGORY:"):
                    metadata["category"] = line.partition(":")[2].strip()

        # Generate summary embedding if embedding model available
        if embd_mdl and metadata["summary"]:
            try:
                summary_vec, _ = embd_mdl.encode([metadata["summary"]])
                # Explicit None/length checks: embeddings may be numpy
                # arrays, whose bare truth value raises ValueError (the
                # old `if summary_vec and ...` silently failed here).
                if summary_vec is not None and len(summary_vec) > 0:
                    metadata["summary_vector"] = summary_vec[0]
            except Exception:
                pass

        logging.info(f"Generated metadata for doc {doc_id}: {len(metadata['topics'])} topics, {len(metadata['suggested_questions'])} questions")
        return metadata

    except Exception as e:
        logging.error(f"Document metadata generation error: {e}")
        return metadata
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Unit tests for NLP module.""" diff --git a/test/unit_test/nlp/test_hierarchical_retrieval.py b/test/unit_test/nlp/test_hierarchical_retrieval.py new file mode 100644 index 000000000..dc52d38cf --- /dev/null +++ b/test/unit_test/nlp/test_hierarchical_retrieval.py @@ -0,0 +1,699 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for hierarchical retrieval integration in search.py. 
class TestHierarchicalConfig:
    """Test HierarchicalConfig dataclass."""

    def test_default_values(self):
        """A bare config carries the documented defaults."""
        config = HierarchicalConfig()

        expected_defaults = {
            "enabled": False,
            "enable_kb_routing": True,
            "kb_routing_threshold": 0.3,
            "kb_top_k": 3,
            "enable_doc_filtering": True,
            "doc_top_k": 100,
            "chunk_top_k": 10,
        }
        for attr, expected in expected_defaults.items():
            assert getattr(config, attr) == expected

    def test_custom_values(self):
        """Keyword arguments override the corresponding defaults."""
        config = HierarchicalConfig(
            enabled=True,
            kb_routing_threshold=0.5,
            kb_top_k=5,
            doc_top_k=50,
        )

        assert config.enabled is True
        assert config.kb_routing_threshold == 0.5
        assert config.kb_top_k == 5
        assert config.doc_top_k == 50
test_routing_disabled(self, mock_dealer): + """Test that routing returns all KBs when disabled.""" + config = HierarchicalConfig(enable_kb_routing=False) + kb_ids = ["kb1", "kb2", "kb3"] + + result = mock_dealer._tier1_kb_routing("test query", kb_ids, None, config) + + assert result == kb_ids + + def test_routing_no_kb_infos(self, mock_dealer): + """Test that routing returns all KBs when no KB info provided.""" + config = HierarchicalConfig(enable_kb_routing=True) + kb_ids = ["kb1", "kb2", "kb3"] + + result = mock_dealer._tier1_kb_routing("test query", kb_ids, None, config) + + assert result == kb_ids + + def test_routing_few_kbs(self, mock_dealer): + """Test that routing returns all KBs when count <= kb_top_k.""" + config = HierarchicalConfig(enable_kb_routing=True, kb_top_k=5) + kb_ids = ["kb1", "kb2", "kb3"] + kb_infos = [ + {"id": "kb1", "name": "Finance KB", "description": "Financial documents"}, + {"id": "kb2", "name": "HR KB", "description": "Human resources"}, + {"id": "kb3", "name": "Tech KB", "description": "Technical docs"}, + ] + + result = mock_dealer._tier1_kb_routing("test query", kb_ids, kb_infos, config) + + assert result == kb_ids + + def test_routing_selects_relevant_kbs(self, mock_dealer): + """Test that routing selects KBs based on keyword overlap.""" + config = HierarchicalConfig( + enable_kb_routing=True, + kb_top_k=2, + kb_routing_threshold=0.1 + ) + kb_ids = ["kb1", "kb2", "kb3", "kb4"] + kb_infos = [ + {"id": "kb1", "name": "Finance Reports", "description": "Financial analysis and reports"}, + {"id": "kb2", "name": "HR Policies", "description": "Human resources policies"}, + {"id": "kb3", "name": "Technical Documentation", "description": "Engineering docs"}, + {"id": "kb4", "name": "Financial Statements", "description": "Quarterly financial data"}, + ] + + # Query about finance should select finance-related KBs + result = mock_dealer._tier1_kb_routing("financial report analysis", kb_ids, kb_infos, config) + + # Should select at most 
kb_top_k KBs + assert len(result) <= config.kb_top_k + # Should include finance-related KBs + assert any(kb in result for kb in ["kb1", "kb4"]) + + +class TestTier2DocumentFiltering: + """Test Tier 2: Document Filtering logic.""" + + @pytest.fixture + def mock_dealer(self): + """Create a mock Dealer instance.""" + mock_datastore = Mock() + dealer = Dealer(mock_datastore) + return dealer + + def test_filtering_disabled(self, mock_dealer): + """Test that filtering returns empty when disabled.""" + config = HierarchicalConfig(enable_doc_filtering=False) + + result = mock_dealer._tier2_document_filtering( + "test query", ["tenant1"], ["kb1"], None, config + ) + + assert result == [] + + def test_filtering_with_existing_doc_ids(self, mock_dealer): + """Test that filtering limits existing doc_ids.""" + config = HierarchicalConfig(enable_doc_filtering=True, doc_top_k=2) + doc_ids = ["doc1", "doc2", "doc3", "doc4"] + + result = mock_dealer._tier2_document_filtering( + "test query", ["tenant1"], ["kb1"], doc_ids, config + ) + + assert len(result) == 2 + assert result == ["doc1", "doc2"] + + +class TestHierarchicalRetrieval: + """Test the full hierarchical retrieval flow.""" + + @pytest.fixture + def mock_dealer(self): + """Create a mock Dealer instance with mocked methods.""" + mock_datastore = Mock() + dealer = Dealer(mock_datastore) + + # Mock the retrieval method + dealer.retrieval = Mock(return_value={ + "total": 5, + "chunks": [ + {"chunk_id": "c1", "content_with_weight": "test content 1"}, + {"chunk_id": "c2", "content_with_weight": "test content 2"}, + ], + "doc_aggs": [], + }) + + return dealer + + def test_hierarchical_retrieval_basic(self, mock_dealer): + """Test basic hierarchical retrieval flow.""" + config = HierarchicalConfig( + enabled=True, + enable_kb_routing=False, # Skip routing for this test + enable_doc_filtering=False, # Skip filtering for this test + ) + + result = mock_dealer.hierarchical_retrieval( + question="test query", + embd_mdl=Mock(), + 
tenant_ids=["tenant1"], + kb_ids=["kb1", "kb2"], + hierarchical_config=config, + ) + + # Should have chunks from retrieval + assert "chunks" in result + assert len(result["chunks"]) == 2 + + # Should have hierarchical metadata + assert "hierarchical_metadata" in result + metadata = result["hierarchical_metadata"] + assert "tier1_time_ms" in metadata + assert "tier2_time_ms" in metadata + assert "tier3_time_ms" in metadata + assert "total_time_ms" in metadata + + def test_hierarchical_retrieval_with_kb_infos(self, mock_dealer): + """Test hierarchical retrieval with KB information.""" + config = HierarchicalConfig( + enabled=True, + enable_kb_routing=True, + kb_top_k=2, + ) + + kb_infos = [ + {"id": "kb1", "name": "Finance", "description": "Financial docs"}, + {"id": "kb2", "name": "HR", "description": "HR policies"}, + ] + + result = mock_dealer.hierarchical_retrieval( + question="financial report", + embd_mdl=Mock(), + tenant_ids=["tenant1"], + kb_ids=["kb1", "kb2"], + kb_infos=kb_infos, + hierarchical_config=config, + ) + + assert "hierarchical_metadata" in result + assert "selected_kb_ids" in result["hierarchical_metadata"] + + def test_hierarchical_retrieval_empty_query(self, mock_dealer): + """Test hierarchical retrieval with empty query.""" + # Mock retrieval to return empty for empty query + mock_dealer.retrieval = Mock(return_value={ + "total": 0, + "chunks": [], + "doc_aggs": [], + }) + + config = HierarchicalConfig(enabled=True) + + result = mock_dealer.hierarchical_retrieval( + question="", + embd_mdl=Mock(), + tenant_ids=["tenant1"], + kb_ids=["kb1"], + hierarchical_config=config, + ) + + assert result["total"] == 0 + assert result["chunks"] == [] + + +class TestIndexName: + """Test index_name utility function.""" + + def test_index_name_format(self): + """Test index name formatting.""" + assert index_name("user123") == "ragflow_user123" + assert index_name("tenant_abc") == "ragflow_tenant_abc" + + +class TestHierarchicalConfigAdvanced: + """Test 
advanced HierarchicalConfig options.""" + + def test_all_config_options(self): + """Test all configuration options.""" + config = HierarchicalConfig( + enabled=True, + enable_kb_routing=True, + kb_routing_method="llm_based", + kb_routing_threshold=0.5, + kb_top_k=5, + enable_doc_filtering=True, + doc_top_k=50, + metadata_fields=["department", "doc_type"], + enable_metadata_similarity=True, + metadata_similarity_threshold=0.8, + use_llm_metadata_filter=True, + chunk_top_k=20, + enable_parent_child=True, + use_summary_mapping=True, + keyword_extraction_prompt="Extract important technical terms", + question_generation_prompt="Generate questions about the content", + ) + + assert config.kb_routing_method == "llm_based" + assert config.metadata_fields == ["department", "doc_type"] + assert config.enable_metadata_similarity is True + assert config.use_llm_metadata_filter is True + assert config.enable_parent_child is True + assert config.use_summary_mapping is True + assert config.keyword_extraction_prompt is not None + + +class TestLLMKBRouting: + """Test LLM-based KB routing.""" + + @pytest.fixture + def mock_dealer(self): + """Create a mock Dealer instance.""" + mock_datastore = Mock() + dealer = Dealer(mock_datastore) + return dealer + + def test_llm_routing_fallback_no_model(self, mock_dealer): + """Test LLM routing falls back when no model provided.""" + config = HierarchicalConfig( + enable_kb_routing=True, + kb_routing_method="llm_based", + kb_top_k=2 + ) + kb_ids = ["kb1", "kb2", "kb3", "kb4"] + kb_infos = [ + {"id": "kb1", "name": "Finance", "description": "Financial docs"}, + {"id": "kb2", "name": "HR", "description": "HR policies"}, + {"id": "kb3", "name": "Tech", "description": "Technical docs"}, + {"id": "kb4", "name": "Legal", "description": "Legal documents"}, + ] + + # Without chat_mdl, should fall back to rule-based + result = mock_dealer._tier1_kb_routing( + "financial report", kb_ids, kb_infos, config, chat_mdl=None + ) + + # Should still return 
class TestMetadataSimilarityFilter:
    """Test metadata similarity filtering."""

    @pytest.fixture
    def mock_dealer(self):
        """Build a Dealer around a mocked datastore."""
        return Dealer(Mock())

    def test_similarity_filter_no_model(self, mock_dealer):
        """Without an embedding model the filter yields no documents."""
        config = HierarchicalConfig(
            enable_metadata_similarity=True,
            metadata_similarity_threshold=0.7,
        )
        doc_metadata = [
            {"id": "doc1", "name": "Finance Report", "summary": "Q1 financial analysis"},
            {"id": "doc2", "name": "HR Policy", "summary": "Employee guidelines"},
        ]

        filtered = mock_dealer._metadata_similarity_filter(
            "financial analysis", doc_metadata, config, embd_mdl=None
        )
        assert filtered == []

    def test_similarity_filter_no_metadata(self, mock_dealer):
        """Without document metadata the filter yields no documents."""
        config = HierarchicalConfig(enable_metadata_similarity=True)

        filtered = mock_dealer._metadata_similarity_filter(
            "test query", [], config, embd_mdl=Mock()
        )
        assert filtered == []
search method + mock_search_result = Mock() + mock_search_result.ids = ["chunk1", "chunk2"] + mock_search_result.field = { + "chunk1": { + "content_ltks": "parent content", + "content_with_weight": "parent content", + "doc_id": "doc1", + "docnm_kwd": "test.pdf", + "kb_id": "kb1", + "mom_id": "", + }, + "chunk2": { + "content_ltks": "child content", + "content_with_weight": "child content", + "doc_id": "doc1", + "docnm_kwd": "test.pdf", + "kb_id": "kb1", + "mom_id": "chunk1", + }, + } + mock_search_result.total = 2 + mock_search_result.query_vector = [0.1] * 768 + mock_search_result.highlight = {} + + dealer.search = Mock(return_value=mock_search_result) + dealer.rerank = Mock(return_value=([0.9, 0.8], [0.5, 0.4], [0.9, 0.8])) + + return dealer + + def test_summary_mapping_config(self): + """Test summary mapping configuration.""" + config = HierarchicalConfig( + enable_parent_child=True, + use_summary_mapping=True, + ) + + assert config.enable_parent_child is True + assert config.use_summary_mapping is True + + +class TestCustomKeywordExtraction: + """Test customizable keyword extraction.""" + + @pytest.fixture + def mock_dealer(self): + """Create a mock Dealer instance.""" + mock_datastore = Mock() + dealer = Dealer(mock_datastore) + return dealer + + def test_keyword_extraction_important(self, mock_dealer): + """Test keyword extraction with 'important' prompt.""" + chunks = [ + { + "content_with_weight": "The API_KEY and DatabaseConnection are critical components.", + "important_kwd": ["existing_term"], # Pre-existing keyword + } + ] + + result = mock_dealer._apply_custom_keyword_extraction( + chunks, "Extract important technical terms", None + ) + + # Should return the chunk (possibly with extracted keywords) + assert len(result) == 1 + # Should preserve existing keywords at minimum + assert "existing_term" in result[0].get("important_kwd", []) + + def test_keyword_extraction_question(self, mock_dealer): + """Test keyword extraction with 'question' prompt.""" + 
chunks = [ + { + "content_with_weight": "This section explains the authentication flow for the system. Users must provide valid credentials.", + "important_kwd": [], + } + ] + + result = mock_dealer._apply_custom_keyword_extraction( + chunks, "Generate questions about the content", None + ) + + assert len(result) == 1 + # Should have a question hint + assert "question_hint" in result[0] + + def test_keyword_extraction_empty_content(self, mock_dealer): + """Test keyword extraction with empty content.""" + chunks = [{"content_with_weight": "", "important_kwd": []}] + + result = mock_dealer._apply_custom_keyword_extraction( + chunks, "Extract important terms", None + ) + + assert len(result) == 1 + + +class TestKBRetrievalParams: + """Test per-KB retrieval parameters.""" + + def test_default_params(self): + """Test default KB params.""" + params = KBRetrievalParams(kb_id="kb1") + + assert params.kb_id == "kb1" + assert params.vector_similarity_weight == 0.7 + assert params.similarity_threshold == 0.2 + assert params.top_k == 1024 + assert params.rerank_enabled is True + + def test_custom_params(self): + """Test custom KB params.""" + params = KBRetrievalParams( + kb_id="finance_kb", + vector_similarity_weight=0.9, + similarity_threshold=0.3, + top_k=500, + rerank_enabled=False + ) + + assert params.kb_id == "finance_kb" + assert params.vector_similarity_weight == 0.9 + assert params.similarity_threshold == 0.3 + assert params.top_k == 500 + assert params.rerank_enabled is False + + def test_kb_params_in_config(self): + """Test KB params integration in HierarchicalConfig.""" + kb_params = { + "kb1": KBRetrievalParams(kb_id="kb1", vector_similarity_weight=0.8), + "kb2": KBRetrievalParams(kb_id="kb2", similarity_threshold=0.4), + } + + config = HierarchicalConfig( + enabled=True, + kb_params=kb_params + ) + + assert len(config.kb_params) == 2 + assert config.kb_params["kb1"].vector_similarity_weight == 0.8 + assert config.kb_params["kb2"].similarity_threshold == 0.4 + + 
class TestLLMQuestionGeneration:
    """Covers LLM-based question generation over retrieved chunks."""

    @pytest.fixture
    def mock_dealer(self):
        """A Dealer backed by a mocked datastore."""
        return Dealer(Mock())

    def test_question_generation_no_model(self, mock_dealer):
        """Without an LLM the chunks come back exactly as given."""
        chunks = [{"content_with_weight": "Test content", "important_kwd": []}]

        out = mock_dealer._apply_llm_question_generation(chunks, None, None)

        assert out == chunks
        assert "generated_questions" not in out[0]

    def test_question_generation_with_mock_model(self, mock_dealer):
        """A mocked LLM reply is split into one question per line."""
        mock_chat = Mock()
        mock_chat.chat = Mock(return_value="What is the main topic?\nHow does this work?")

        chunks = [
            {
                "content_with_weight": "This is a detailed explanation of the authentication system. It uses OAuth2 for secure access.",
                "important_kwd": [],
            }
        ]

        out = mock_dealer._apply_llm_question_generation(chunks, None, mock_chat)

        assert len(out) == 1
        assert "generated_questions" in out[0]
        assert len(out[0]["generated_questions"]) == 2


class TestDocumentMetadataGeneration:
    """Covers LLM-assisted document metadata generation."""

    @pytest.fixture
    def mock_dealer(self):
        """A Dealer backed by a mocked datastore."""
        return Dealer(Mock())

    def test_metadata_generation_no_model(self, mock_dealer):
        """Without an LLM only an empty metadata skeleton is produced."""
        meta = mock_dealer.generate_document_metadata("doc1", "content", None, None)

        assert meta["doc_id"] == "doc1"
        assert meta["generated"] is True
        assert meta["summary"] == ""
        assert meta["topics"] == []

    def test_metadata_generation_with_mock_model(self, mock_dealer):
        """A structured LLM reply is parsed into the metadata fields."""
        mock_chat = Mock()
        mock_chat.chat = Mock(return_value="""SUMMARY: This document explains authentication.
TOPICS: security, OAuth, authentication
QUESTIONS:
- What is OAuth?
- How does authentication work?
CATEGORY: Technical Documentation""")

        meta = mock_dealer.generate_document_metadata(
            "doc1",
            "This is content about authentication and security.",
            mock_chat,
            None,
        )

        assert meta["doc_id"] == "doc1"
        assert "authentication" in meta["summary"]
        assert len(meta["topics"]) == 3
        assert "security" in meta["topics"]
        assert len(meta["suggested_questions"]) == 2
        assert meta["category"] == "Technical Documentation"


class TestFullHierarchicalConfig:
    """Covers a fully-enabled hierarchical configuration."""

    def test_all_features_enabled(self):
        """All three tiers plus prompt options coexist in one config."""
        cfg = HierarchicalConfig(
            enabled=True,
            # Tier 1: KB routing
            enable_kb_routing=True,
            kb_routing_method="llm_based",
            kb_routing_threshold=0.4,
            kb_top_k=5,
            kb_params={
                "kb1": KBRetrievalParams(kb_id="kb1", vector_similarity_weight=0.9)
            },
            # Tier 2: document filtering
            enable_doc_filtering=True,
            doc_top_k=50,
            metadata_fields=["department", "author"],
            enable_metadata_similarity=True,
            metadata_similarity_threshold=0.8,
            use_llm_metadata_filter=True,
            # Tier 3: chunk retrieval
            chunk_top_k=20,
            enable_parent_child=True,
            use_summary_mapping=True,
            # Prompt customization
            keyword_extraction_prompt="Extract domain-specific terms",
            question_generation_prompt="Generate FAQ questions",
            use_llm_question_generation=True,
        )

        assert cfg.enabled is True
        assert cfg.kb_routing_method == "llm_based"
        assert len(cfg.kb_params) == 1
        assert cfg.enable_metadata_similarity is True
        assert cfg.use_llm_metadata_filter is True
        assert cfg.enable_parent_child is True
        assert cfg.use_summary_mapping is True
        assert cfg.use_llm_question_generation is True


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

# NOTE(review): the original chunk ends here with a git diff header that
# introduces the new file test/unit_test/services/test_metadata_service.py
# (new file mode); the header itself carries no executable content.
100644 index 000000000..46b435dc3 --- /dev/null +++ b/test/unit_test/services/test_metadata_service.py @@ -0,0 +1,230 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for MetadataService. + +Tests batch CRUD operations for document metadata management. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock + + +class TestMetadataServiceBatchGet: + """Test batch_get_metadata functionality.""" + + def test_batch_get_empty_ids(self): + """Test batch get with empty doc_ids returns empty dict.""" + from api.db.services.metadata_service import MetadataService + + with patch.object(MetadataService, 'batch_get_metadata', return_value={}): + result = MetadataService.batch_get_metadata([]) + assert result == {} + + def test_batch_get_with_fields_filter(self): + """Test batch get filters to requested fields.""" + # This tests the logic of field filtering + full_metadata = {"field1": "value1", "field2": "value2", "field3": "value3"} + requested_fields = ["field1", "field3"] + + filtered = {k: v for k, v in full_metadata.items() if k in requested_fields} + + assert "field1" in filtered + assert "field3" in filtered + assert "field2" not in filtered + + +class TestMetadataServiceBatchUpdate: + """Test batch_update_metadata functionality.""" + + def test_update_merge_logic(self): + """Test metadata merge logic.""" + existing = {"field1": "old_value", "field2": "keep_this"} + 
new_metadata = {"field1": "new_value", "field3": "added"} + + # Merge logic + existing.update(new_metadata) + + assert existing["field1"] == "new_value" # Updated + assert existing["field2"] == "keep_this" # Preserved + assert existing["field3"] == "added" # Added + + def test_update_replace_logic(self): + """Test metadata replace logic.""" + existing = {"field1": "old_value", "field2": "keep_this"} + new_metadata = {"field1": "new_value", "field3": "added"} + + # Replace logic (don't merge) + result = new_metadata + + assert result["field1"] == "new_value" + assert "field2" not in result # Not preserved + assert result["field3"] == "added" + + +class TestMetadataServiceBatchDelete: + """Test batch_delete_metadata_fields functionality.""" + + def test_delete_fields_logic(self): + """Test field deletion logic.""" + metadata = {"field1": "value1", "field2": "value2", "field3": "value3"} + fields_to_delete = ["field1", "field3"] + + for field in fields_to_delete: + if field in metadata: + del metadata[field] + + assert "field1" not in metadata + assert "field2" in metadata + assert "field3" not in metadata + + +class TestMetadataServiceSearch: + """Test search_by_metadata functionality.""" + + def test_equals_filter(self): + """Test equals filter logic.""" + doc_value = "Technical" + condition = "Technical" + + matches = str(doc_value) == str(condition) + assert matches is True + + def test_contains_filter(self): + """Test contains filter logic.""" + doc_value = "Technical Documentation" + condition = {"contains": "Technical"} + + val = condition["contains"] + matches = str(val).lower() in str(doc_value).lower() + assert matches is True + + def test_starts_with_filter(self): + """Test starts_with filter logic.""" + doc_value = "Technical Documentation" + condition = {"starts_with": "Tech"} + + val = condition["starts_with"] + matches = str(doc_value).lower().startswith(str(val).lower()) + assert matches is True + + def test_gt_filter(self): + """Test greater than 
filter logic.""" + doc_value = 2023 + condition = {"gt": 2020} + + val = condition["gt"] + matches = float(doc_value) > float(val) + assert matches is True + + def test_lt_filter(self): + """Test less than filter logic.""" + doc_value = 2019 + condition = {"lt": 2020} + + val = condition["lt"] + matches = float(doc_value) < float(val) + assert matches is True + + def test_in_filter(self): + """Test in filter logic.""" + doc_value = "Technical" + condition = {"in": ["Technical", "Legal", "HR"]} + + val = condition["in"] + matches = doc_value in val + assert matches is True + + +class TestMetadataServiceSchema: + """Test get_metadata_schema functionality.""" + + def test_schema_type_detection(self): + """Test type detection in schema.""" + values = [ + ("string_value", "str"), + (123, "int"), + (12.5, "float"), + (True, "bool"), + (["a", "b"], "list"), + ] + + for value, expected_type in values: + detected_type = type(value).__name__ + assert detected_type == expected_type + + def test_schema_sample_values_limit(self): + """Test sample values are limited.""" + sample_values = set() + max_samples = 10 + + for i in range(20): + if len(sample_values) < max_samples: + sample_values.add(f"value_{i}") + + assert len(sample_values) == max_samples + + +class TestMetadataServiceStatistics: + """Test get_metadata_statistics functionality.""" + + def test_coverage_calculation(self): + """Test metadata coverage calculation.""" + total_docs = 100 + docs_with_metadata = 80 + + coverage = docs_with_metadata / total_docs if total_docs > 0 else 0 + + assert coverage == 0.8 + + def test_coverage_zero_docs(self): + """Test coverage with zero documents.""" + total_docs = 0 + docs_with_metadata = 0 + + coverage = docs_with_metadata / total_docs if total_docs > 0 else 0 + + assert coverage == 0 + + +class TestMetadataServiceCopy: + """Test copy_metadata functionality.""" + + def test_copy_all_fields(self): + """Test copying all metadata fields.""" + source_meta = {"field1": "value1", 
"field2": "value2"} + + # Copy all + copied = source_meta.copy() + + assert copied == source_meta + assert copied is not source_meta # Different object + + def test_copy_specific_fields(self): + """Test copying specific metadata fields.""" + source_meta = {"field1": "value1", "field2": "value2", "field3": "value3"} + fields = ["field1", "field3"] + + copied = {k: v for k, v in source_meta.items() if k in fields} + + assert "field1" in copied + assert "field2" not in copied + assert "field3" in copied + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])