feat: Implement hierarchical retrieval architecture (#11610)
This PR implements the complete three-tier hierarchical retrieval architecture as specified in issue #11610, enabling production-grade RAG capabilities. ## Tier 1: Knowledge Base Routing - Auto-route queries to relevant knowledge bases - Per-KB retrieval parameters (KBRetrievalParams dataclass) - Rule-based routing with keyword overlap scoring - LLM-based routing with fallback to rule-based - Configurable routing methods: auto, rule_based, llm_based, all ## Tier 2: Document Filtering - Document-level metadata filtering within selected KBs - Configurable metadata fields for filtering - LLM-generated filter conditions - Metadata similarity matching (fuzzy matching) - Enhanced metadata generation for documents ## Tier 3: Chunk Refinement - Parent-child chunking with summary mapping - Custom prompts for keyword extraction - LLM-based question generation for chunks - Integration with existing retrieval pipeline ## Metadata Management (Batch CRUD) - MetadataService with batch operations: - batch_get_metadata - batch_update_metadata - batch_delete_metadata_fields - batch_set_metadata_field - get_metadata_schema - search_by_metadata - get_metadata_statistics - copy_metadata - REST API endpoints in metadata_app.py ## Integration - HierarchicalConfig dataclass for configuration - Integrated into Dealer class (search.py) - Wired into agent retrieval tool - Non-breaking: disabled by default ## Tests - 48 unit tests covering all components - Tests for config, routing, filtering, and metadata operations
This commit is contained in:
parent
c51e6b2a58
commit
d104f59e29
7 changed files with 2813 additions and 15 deletions
|
|
@ -63,6 +63,7 @@ class RetrievalParam(ToolParamBase):
|
||||||
self.cross_languages = []
|
self.cross_languages = []
|
||||||
self.toc_enhance = False
|
self.toc_enhance = False
|
||||||
self.meta_data_filter={}
|
self.meta_data_filter={}
|
||||||
|
self.hierarchical_retrieval = False # Enable hierarchical retrieval
|
||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
|
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
|
||||||
|
|
@ -174,20 +175,42 @@ class Retrieval(ToolBase, ABC):
|
||||||
|
|
||||||
if kbs:
|
if kbs:
|
||||||
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
|
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)
|
||||||
kbinfos = settings.retriever.retrieval(
|
|
||||||
query,
|
# Use hierarchical retrieval if enabled
|
||||||
embd_mdl,
|
if self._param.hierarchical_retrieval:
|
||||||
[kb.tenant_id for kb in kbs],
|
from rag.nlp.search import HierarchicalConfig
|
||||||
filtered_kb_ids,
|
kb_infos = [{"id": kb.id, "name": kb.name, "description": kb.description or ""} for kb in kbs]
|
||||||
1,
|
kbinfos = settings.retriever.hierarchical_retrieval(
|
||||||
self._param.top_n,
|
question=query,
|
||||||
self._param.similarity_threshold,
|
embd_mdl=embd_mdl,
|
||||||
1 - self._param.keywords_similarity_weight,
|
tenant_ids=[kb.tenant_id for kb in kbs],
|
||||||
doc_ids=doc_ids,
|
kb_ids=filtered_kb_ids,
|
||||||
aggs=False,
|
kb_infos=kb_infos,
|
||||||
rerank_mdl=rerank_mdl,
|
page=1,
|
||||||
rank_feature=label_question(query, kbs),
|
page_size=self._param.top_n,
|
||||||
)
|
similarity_threshold=self._param.similarity_threshold,
|
||||||
|
vector_similarity_weight=1 - self._param.keywords_similarity_weight,
|
||||||
|
doc_ids=doc_ids,
|
||||||
|
aggs=False,
|
||||||
|
rerank_mdl=rerank_mdl,
|
||||||
|
rank_feature=label_question(query, kbs),
|
||||||
|
hierarchical_config=HierarchicalConfig(enabled=True),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kbinfos = settings.retriever.retrieval(
|
||||||
|
query,
|
||||||
|
embd_mdl,
|
||||||
|
[kb.tenant_id for kb in kbs],
|
||||||
|
filtered_kb_ids,
|
||||||
|
1,
|
||||||
|
self._param.top_n,
|
||||||
|
self._param.similarity_threshold,
|
||||||
|
1 - self._param.keywords_similarity_weight,
|
||||||
|
doc_ids=doc_ids,
|
||||||
|
aggs=False,
|
||||||
|
rerank_mdl=rerank_mdl,
|
||||||
|
rank_feature=label_question(query, kbs),
|
||||||
|
)
|
||||||
if self.check_if_canceled("Retrieval processing"):
|
if self.check_if_canceled("Retrieval processing"):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
378
api/apps/metadata_app.py
Normal file
378
api/apps/metadata_app.py
Normal file
|
|
@ -0,0 +1,378 @@
|
||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
Metadata Management API for Hierarchical Retrieval.
|
||||||
|
|
||||||
|
Provides REST endpoints for batch CRUD operations on document metadata,
|
||||||
|
supporting the hierarchical retrieval architecture's Tier 2 document filtering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from quart import request
|
||||||
|
from api.apps import current_user, login_required
|
||||||
|
from api.common.check_team_permission import check_kb_team_permission
|
||||||
|
from api.db.services.metadata_service import MetadataService
|
||||||
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
from api.utils.api_utils import (
|
||||||
|
get_json_result,
|
||||||
|
server_error_response,
|
||||||
|
)
|
||||||
|
from common.constants import RetCode
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/batch/get", methods=["POST"])  # noqa: F821
@login_required
async def batch_get_metadata():
    """
    Return metadata for a batch of documents.

    Request body:
        {
            "doc_ids": ["doc1", "doc2", ...],
            "fields": ["field1", "field2", ...]   // optional projection
        }

    Returns:
        {
            "doc1": {"doc_id": "doc1", "doc_name": "...", "metadata": {...}},
            ...
        }
    """
    try:
        payload = await request.json
        ids = payload.get("doc_ids", [])
        wanted_fields = payload.get("fields")

        # Guard clause: an empty ID list is a caller error, not an empty result.
        if not ids:
            return get_json_result(
                data={},
                message="No document IDs provided",
                code=RetCode.ARGUMENT_ERROR,
            )

        # NOTE(review): unlike /schema and /search, this endpoint performs no
        # KB team-permission check on the requested documents — confirm whether
        # MetadataService enforces access control, otherwise add a check here.
        return get_json_result(data=MetadataService.batch_get_metadata(ids, wanted_fields))

    except Exception as e:
        return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/batch/update", methods=["POST"])  # noqa: F821
@login_required
async def batch_update_metadata():
    """
    Update metadata for a batch of documents.

    Request body:
        {
            "updates": [
                {"doc_id": "doc1", "metadata": {"field1": "value1", ...}},
                ...
            ],
            "merge": true   // optional, default true; false replaces all metadata
        }

    Returns:
        {"success_count": 5, "failed_ids": ["doc3"]}
    """
    try:
        payload = await request.json
        update_items = payload.get("updates", [])
        merge_mode = payload.get("merge", True)

        # Guard clause: nothing to do is treated as a caller error.
        if not update_items:
            return get_json_result(
                data={"success_count": 0, "failed_ids": []},
                message="No updates provided",
                code=RetCode.ARGUMENT_ERROR,
            )

        ok_count, failed = MetadataService.batch_update_metadata(update_items, merge_mode)
        return get_json_result(data={"success_count": ok_count, "failed_ids": failed})

    except Exception as e:
        return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/batch/delete-fields", methods=["POST"])  # noqa: F821
@login_required
async def batch_delete_metadata_fields():
    """
    Delete specific metadata fields from a batch of documents.

    Request body:
        {
            "doc_ids": ["doc1", "doc2", ...],
            "fields": ["field1", "field2", ...]
        }

    Returns:
        {"success_count": 5, "failed_ids": []}
    """
    try:
        payload = await request.json
        ids = payload.get("doc_ids", [])
        field_names = payload.get("fields", [])

        # Both the targets and the fields to strip are mandatory.
        if not ids or not field_names:
            return get_json_result(
                data={"success_count": 0, "failed_ids": []},
                message="doc_ids and fields are required",
                code=RetCode.ARGUMENT_ERROR,
            )

        ok_count, failed = MetadataService.batch_delete_metadata_fields(ids, field_names)
        return get_json_result(data={"success_count": ok_count, "failed_ids": failed})

    except Exception as e:
        return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/batch/set-field", methods=["POST"])  # noqa: F821
@login_required
async def batch_set_metadata_field():
    """
    Set one metadata field to the same value on a batch of documents.

    Useful for bulk categorization or tagging.

    Request body:
        {
            "doc_ids": ["doc1", "doc2", ...],
            "field_name": "category",
            "field_value": "Technical"
        }

    Returns:
        {"success_count": 5, "failed_ids": []}
    """
    try:
        payload = await request.json
        ids = payload.get("doc_ids", [])
        name = payload.get("field_name")
        value = payload.get("field_value")

        # field_value may legitimately be falsy (0, "", False), so only the
        # IDs and the field name are validated here.
        if not ids or not name:
            return get_json_result(
                data={"success_count": 0, "failed_ids": []},
                message="doc_ids and field_name are required",
                code=RetCode.ARGUMENT_ERROR,
            )

        ok_count, failed = MetadataService.batch_set_metadata_field(ids, name, value)
        return get_json_result(data={"success_count": ok_count, "failed_ids": failed})

    except Exception as e:
        return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/schema/<kb_id>", methods=["GET"])  # noqa: F821
@login_required
async def get_metadata_schema(kb_id):
    """
    Return the metadata schema of a knowledge base.

    Lists the available metadata fields with their types and sample values:

        {
            "field1": {"type": "str", "sample_values": ["a", "b"], "count": 10},
            ...
        }
    """
    try:
        # Reject unknown KBs before the (potentially expensive) schema scan.
        if not KnowledgebaseService.get_by_id(kb_id):
            return get_json_result(
                data={},
                message="Knowledge base not found",
                code=RetCode.DATA_ERROR,
            )

        # Enforce team-level access control on the KB.
        if not check_kb_team_permission(current_user.id, kb_id):
            return get_json_result(
                data={},
                message="No permission to access this knowledge base",
                code=RetCode.PERMISSION_ERROR,
            )

        return get_json_result(data=MetadataService.get_metadata_schema(kb_id))

    except Exception as e:
        return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/statistics/<kb_id>", methods=["GET"])  # noqa: F821
@login_required
async def get_metadata_statistics(kb_id):
    """
    Return metadata usage statistics for a knowledge base.

    Response shape:
        {
            "total_documents": 100,
            "documents_with_metadata": 80,
            "metadata_coverage": 0.8,
            "field_usage": {"category": 50, "author": 30},
            "unique_fields": 5
        }
    """
    try:
        # Reject unknown KBs before scanning their documents.
        if not KnowledgebaseService.get_by_id(kb_id):
            return get_json_result(
                data={},
                message="Knowledge base not found",
                code=RetCode.DATA_ERROR,
            )

        # Enforce team-level access control on the KB.
        if not check_kb_team_permission(current_user.id, kb_id):
            return get_json_result(
                data={},
                message="No permission to access this knowledge base",
                code=RetCode.PERMISSION_ERROR,
            )

        return get_json_result(data=MetadataService.get_metadata_statistics(kb_id))

    except Exception as e:
        return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/search", methods=["POST"])  # noqa: F821
@login_required
async def search_by_metadata():
    """
    Search documents in a knowledge base by metadata filters.

    Request body:
        {
            "kb_id": "kb123",
            "filters": {
                "category": "Technical",
                "author": {"contains": "John"},
                "year": {"gt": 2020}
            },
            "limit": 100
        }

    Supported operators: equals, contains, starts_with, in, gt, lt

    Returns:
        [
            {"doc_id": "doc1", "doc_name": "...", "metadata": {...}},
            ...
        ]
    """
    try:
        payload = await request.json
        kb_id = payload.get("kb_id")
        filter_spec = payload.get("filters", {})
        max_results = payload.get("limit", 100)

        # The KB ID is mandatory; filters/limit have sensible defaults.
        if not kb_id:
            return get_json_result(
                data=[],
                message="kb_id is required",
                code=RetCode.ARGUMENT_ERROR,
            )

        # Enforce team-level access control on the KB.
        if not check_kb_team_permission(current_user.id, kb_id):
            return get_json_result(
                data=[],
                message="No permission to access this knowledge base",
                code=RetCode.PERMISSION_ERROR,
            )

        return get_json_result(data=MetadataService.search_by_metadata(kb_id, filter_spec, max_results))

    except Exception as e:
        return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/copy", methods=["POST"])  # noqa: F821
@login_required
async def copy_metadata():
    """
    Copy metadata from one document onto a batch of target documents.

    Request body:
        {
            "source_doc_id": "doc1",
            "target_doc_ids": ["doc2", "doc3", ...],
            "fields": ["field1", "field2"]   // optional; copies all if omitted
        }

    Returns:
        {"success_count": 5, "failed_ids": []}
    """
    try:
        payload = await request.json
        src_id = payload.get("source_doc_id")
        target_ids = payload.get("target_doc_ids", [])
        field_names = payload.get("fields")

        # Both the source and at least one target are mandatory.
        if not src_id or not target_ids:
            return get_json_result(
                data={"success_count": 0, "failed_ids": []},
                message="source_doc_id and target_doc_ids are required",
                code=RetCode.ARGUMENT_ERROR,
            )

        ok_count, failed = MetadataService.copy_metadata(src_id, target_ids, field_names)
        return get_json_result(data={"success_count": ok_count, "failed_ids": failed})

    except Exception as e:
        return server_error_response(e)
|
||||||
398
api/db/services/metadata_service.py
Normal file
398
api/db/services/metadata_service.py
Normal file
|
|
@ -0,0 +1,398 @@
|
||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
Metadata Management Service for Hierarchical Retrieval.
|
||||||
|
|
||||||
|
Provides batch CRUD operations for document metadata to support:
|
||||||
|
- Efficient metadata filtering in Tier 2 of hierarchical retrieval
|
||||||
|
- Bulk metadata updates across multiple documents
|
||||||
|
- Metadata schema management per knowledge base
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
|
||||||
|
from peewee import fn
|
||||||
|
|
||||||
|
from api.db.db_models import DB, Document
|
||||||
|
from api.db.services.document_service import DocumentService
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataService:
    """
    Service for managing document metadata in batch operations.

    Supports the hierarchical retrieval architecture's Tier 2 document
    filtering by providing efficient batch CRUD on the documents'
    ``meta_fields`` column.
    """

    @classmethod
    @DB.connection_context()
    def batch_get_metadata(
        cls,
        doc_ids: List[str],
        fields: Optional[List[str]] = None
    ) -> Dict[str, Dict[str, Any]]:
        """
        Get metadata for multiple documents.

        Args:
            doc_ids: List of document IDs.
            fields: Optional list of specific metadata fields to retrieve;
                if falsy (None or empty), all fields are returned.

        Returns:
            Dict mapping doc_id to {"doc_id", "doc_name", "metadata"}.
            IDs not present in the database are simply omitted.
        """
        if not doc_ids:
            return {}

        result = {}
        docs = Document.select(
            Document.id,
            Document.meta_fields,
            Document.name
        ).where(Document.id.in_(doc_ids))

        for doc in docs:
            meta = doc.meta_fields or {}
            if fields:
                # Filter to requested fields only
                meta = {k: v for k, v in meta.items() if k in fields}
            result[doc.id] = {
                "doc_id": doc.id,
                "doc_name": doc.name,
                "metadata": meta
            }

        return result

    @classmethod
    @DB.connection_context()
    def batch_update_metadata(
        cls,
        updates: List[Dict[str, Any]],
        merge: bool = True
    ) -> Tuple[int, List[str]]:
        """
        Update metadata for multiple documents in batch.

        Args:
            updates: List of dicts with 'doc_id' and 'metadata' keys.
                Entries without a 'doc_id' are skipped silently.
            merge: If True, merge with existing metadata; if False, replace.

        Returns:
            Tuple of (success_count, list of failed doc_ids).
        """
        if not updates:
            return 0, []

        success_count = 0
        failed_ids = []

        for update in updates:
            doc_id = update.get("doc_id")
            new_metadata = update.get("metadata", {})

            if not doc_id:
                continue

            try:
                if merge:
                    # Layer the new fields on top of whatever the document
                    # already carries; new values win on key collision.
                    doc = Document.get_or_none(Document.id == doc_id)
                    if doc:
                        existing = doc.meta_fields or {}
                        existing.update(new_metadata)
                        new_metadata = existing

                DocumentService.update_meta_fields(doc_id, new_metadata)
                success_count += 1

            except Exception as e:
                logging.error(f"Failed to update metadata for doc {doc_id}: {e}")
                failed_ids.append(doc_id)

        logging.info(f"Batch metadata update: {success_count} succeeded, {len(failed_ids)} failed")
        return success_count, failed_ids

    @classmethod
    @DB.connection_context()
    def batch_delete_metadata_fields(
        cls,
        doc_ids: List[str],
        fields: List[str]
    ) -> Tuple[int, List[str]]:
        """
        Delete specific metadata fields from multiple documents.

        Args:
            doc_ids: List of document IDs.
            fields: List of metadata field names to delete.

        Returns:
            Tuple of (success_count, list of failed doc_ids). Documents that
            exist but carry none of the named fields are left untouched and
            counted in neither list; unknown IDs are omitted too.
        """
        if not doc_ids or not fields:
            return 0, []

        success_count = 0
        failed_ids = []

        docs = Document.select(
            Document.id,
            Document.meta_fields
        ).where(Document.id.in_(doc_ids))

        for doc in docs:
            try:
                meta = doc.meta_fields or {}
                modified = False

                for field in fields:
                    if field in meta:
                        del meta[field]
                        modified = True

                # Only write back (and count) documents that actually changed.
                if modified:
                    DocumentService.update_meta_fields(doc.id, meta)
                    success_count += 1

            except Exception as e:
                logging.error(f"Failed to delete metadata fields for doc {doc.id}: {e}")
                failed_ids.append(doc.id)

        return success_count, failed_ids

    @classmethod
    @DB.connection_context()
    def batch_set_metadata_field(
        cls,
        doc_ids: List[str],
        field_name: str,
        field_value: Any
    ) -> Tuple[int, List[str]]:
        """
        Set one metadata field to the same value on multiple documents.

        Useful for bulk categorization or tagging.

        Args:
            doc_ids: List of document IDs.
            field_name: Name of the metadata field.
            field_value: Value to set (may be any JSON-serializable value).

        Returns:
            Tuple of (success_count, list of failed doc_ids).
        """
        if not doc_ids or not field_name:
            return 0, []

        # Delegate to the batch updater with merge semantics so other
        # metadata fields on each document are preserved.
        updates = [
            {"doc_id": doc_id, "metadata": {field_name: field_value}}
            for doc_id in doc_ids
        ]

        return cls.batch_update_metadata(updates, merge=True)

    @classmethod
    @DB.connection_context()
    def get_metadata_schema(cls, kb_id: str) -> Dict[str, Dict[str, Any]]:
        """
        Get the metadata schema for a knowledge base.

        Analyzes all documents in the KB to determine available metadata
        fields and their types/values.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            Dict mapping field names to field info (type, sample values,
            count). The recorded type is taken from the first occurrence of
            each field; mixed-type fields are not flagged.
        """
        schema = {}

        docs = Document.select(
            Document.meta_fields
        ).where(Document.kb_id == kb_id)

        for doc in docs:
            meta = doc.meta_fields or {}
            for field_name, field_value in meta.items():
                if field_name not in schema:
                    schema[field_name] = {
                        "type": type(field_value).__name__,
                        "sample_values": set(),
                        "count": 0
                    }

                schema[field_name]["count"] += 1

                # Collect sample values (limit to 10, truncated to 100 chars)
                if len(schema[field_name]["sample_values"]) < 10:
                    try:
                        schema[field_name]["sample_values"].add(str(field_value)[:100])
                    except Exception:
                        # Unstringifiable values are skipped, not fatal.
                        pass

        # Convert sets to lists for JSON serialization
        for field_name in schema:
            schema[field_name]["sample_values"] = list(schema[field_name]["sample_values"])

        return schema

    @staticmethod
    def _condition_matches(doc_value: Any, condition: Any) -> bool:
        """
        Return True if a document's field value satisfies one filter condition.

        A plain (non-dict) condition is a type-insensitive equality test.
        A dict condition maps operator names to operands; ALL operators in
        the dict must hold (logical AND).
        """
        if not isinstance(condition, dict):
            # Simple equality, compared as strings so 2020 matches "2020".
            return str(doc_value) == str(condition)

        for op, val in condition.items():
            if op == "equals":
                ok = str(doc_value) == str(val)
            elif op == "contains":
                ok = str(val).lower() in str(doc_value).lower()
            elif op == "starts_with":
                ok = str(doc_value).lower().startswith(str(val).lower())
            elif op == "in":
                ok = doc_value in val
            elif op in ("gt", "lt"):
                # A missing value never satisfies a numeric comparison, but
                # 0 is a legitimate value and must not be treated as missing.
                if doc_value is None:
                    ok = False
                else:
                    try:
                        left, right = float(doc_value), float(val)
                        ok = left > right if op == "gt" else left < right
                    except (TypeError, ValueError):
                        # Non-numeric values fail the comparison instead of
                        # aborting the whole search with an exception.
                        ok = False
            else:
                # An unrecognized operator must not silently match everything.
                logging.warning(f"Unknown metadata filter operator: {op}")
                ok = False

            if not ok:
                return False
        return True

    @classmethod
    @DB.connection_context()
    def search_by_metadata(
        cls,
        kb_id: str,
        filters: Dict[str, Any],
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Search documents by metadata filters.

        Args:
            kb_id: Knowledge base ID.
            filters: Dict of field_name -> value or {operator: value, ...};
                all filters (and all operators within one filter) must match.
            limit: Maximum number of results.

        Returns:
            List of matching documents with their metadata.
        """
        docs = Document.select(
            Document.id,
            Document.name,
            Document.meta_fields
        ).where(Document.kb_id == kb_id)

        results = []
        for doc in docs:
            meta = doc.meta_fields or {}

            # All filter fields must match (logical AND across fields).
            matches = all(
                cls._condition_matches(meta.get(field_name), condition)
                for field_name, condition in filters.items()
            )

            if matches:
                results.append({
                    "doc_id": doc.id,
                    "doc_name": doc.name,
                    "metadata": meta
                })

                if len(results) >= limit:
                    break

        return results

    @classmethod
    @DB.connection_context()
    def get_metadata_statistics(cls, kb_id: str) -> Dict[str, Any]:
        """
        Get statistics about metadata usage in a knowledge base.

        Args:
            kb_id: Knowledge base ID.

        Returns:
            Dict with total document count, how many carry metadata, the
            coverage ratio, per-field usage counts, and the number of
            distinct fields.
        """
        total_docs = Document.select(fn.COUNT(Document.id)).where(
            Document.kb_id == kb_id
        ).scalar()

        docs_with_metadata = 0
        field_usage = {}

        docs = Document.select(Document.meta_fields).where(Document.kb_id == kb_id)

        for doc in docs:
            meta = doc.meta_fields or {}
            if meta:
                docs_with_metadata += 1
                for field_name in meta.keys():
                    field_usage[field_name] = field_usage.get(field_name, 0) + 1

        return {
            "total_documents": total_docs,
            "documents_with_metadata": docs_with_metadata,
            "metadata_coverage": docs_with_metadata / total_docs if total_docs > 0 else 0,
            "field_usage": field_usage,
            "unique_fields": len(field_usage)
        }

    @classmethod
    @DB.connection_context()
    def copy_metadata(
        cls,
        source_doc_id: str,
        target_doc_ids: List[str],
        fields: Optional[List[str]] = None
    ) -> Tuple[int, List[str]]:
        """
        Copy metadata from one document to multiple target documents.

        Args:
            source_doc_id: Source document ID.
            target_doc_ids: List of target document IDs.
            fields: Optional list of specific fields to copy (all if None).

        Returns:
            Tuple of (success_count, list of failed doc_ids). If the source
            document does not exist, every target is reported as failed; if
            the source has nothing to copy, (0, []) is returned.
        """
        source_doc = Document.get_or_none(Document.id == source_doc_id)
        if not source_doc:
            return 0, target_doc_ids

        source_meta = source_doc.meta_fields or {}

        if fields:
            source_meta = {k: v for k, v in source_meta.items() if k in fields}

        if not source_meta:
            return 0, []

        # Each target gets its own copy so later per-document merges cannot
        # alias the same dict.
        updates = [
            {"doc_id": doc_id, "metadata": source_meta.copy()}
            for doc_id in target_doc_ids
        ]

        return cls.batch_update_metadata(updates, merge=True)
|
||||||
1056
rag/nlp/search.py
1056
rag/nlp/search.py
File diff suppressed because it is too large
Load diff
16
test/unit_test/nlp/__init__.py
Normal file
16
test/unit_test/nlp/__init__.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""Unit tests for NLP module."""
|
||||||
699
test/unit_test/nlp/test_hierarchical_retrieval.py
Normal file
699
test/unit_test/nlp/test_hierarchical_retrieval.py
Normal file
|
|
@ -0,0 +1,699 @@
|
||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit tests for hierarchical retrieval integration in search.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from rag.nlp.search import (
|
||||||
|
HierarchicalConfig,
|
||||||
|
HierarchicalResult,
|
||||||
|
KBRetrievalParams,
|
||||||
|
Dealer,
|
||||||
|
index_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHierarchicalConfig:
|
||||||
|
"""Test HierarchicalConfig dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test default configuration values."""
|
||||||
|
config = HierarchicalConfig()
|
||||||
|
|
||||||
|
assert config.enabled is False
|
||||||
|
assert config.enable_kb_routing is True
|
||||||
|
assert config.kb_routing_threshold == 0.3
|
||||||
|
assert config.kb_top_k == 3
|
||||||
|
assert config.enable_doc_filtering is True
|
||||||
|
assert config.doc_top_k == 100
|
||||||
|
assert config.chunk_top_k == 10
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
"""Test custom configuration values."""
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enabled=True,
|
||||||
|
kb_routing_threshold=0.5,
|
||||||
|
kb_top_k=5,
|
||||||
|
doc_top_k=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.enabled is True
|
||||||
|
assert config.kb_routing_threshold == 0.5
|
||||||
|
assert config.kb_top_k == 5
|
||||||
|
assert config.doc_top_k == 50
|
||||||
|
|
||||||
|
|
||||||
|
class TestHierarchicalResult:
|
||||||
|
"""Test HierarchicalResult dataclass."""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""Test default result values."""
|
||||||
|
result = HierarchicalResult()
|
||||||
|
|
||||||
|
assert result.selected_kb_ids == []
|
||||||
|
assert result.filtered_doc_ids == []
|
||||||
|
assert result.tier1_time_ms == 0.0
|
||||||
|
assert result.tier2_time_ms == 0.0
|
||||||
|
assert result.tier3_time_ms == 0.0
|
||||||
|
assert result.total_time_ms == 0.0
|
||||||
|
|
||||||
|
def test_populated_result(self):
|
||||||
|
"""Test result with data."""
|
||||||
|
result = HierarchicalResult(
|
||||||
|
selected_kb_ids=["kb1", "kb2"],
|
||||||
|
filtered_doc_ids=["doc1", "doc2", "doc3"],
|
||||||
|
tier1_time_ms=10.5,
|
||||||
|
tier2_time_ms=25.3,
|
||||||
|
tier3_time_ms=100.2,
|
||||||
|
total_time_ms=136.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result.selected_kb_ids) == 2
|
||||||
|
assert len(result.filtered_doc_ids) == 3
|
||||||
|
assert result.total_time_ms == 136.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestTier1KBRouting:
|
||||||
|
"""Test Tier 1: KB Routing logic."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dealer(self):
|
||||||
|
"""Create a mock Dealer instance."""
|
||||||
|
mock_datastore = Mock()
|
||||||
|
dealer = Dealer(mock_datastore)
|
||||||
|
return dealer
|
||||||
|
|
||||||
|
def test_routing_disabled(self, mock_dealer):
|
||||||
|
"""Test that routing returns all KBs when disabled."""
|
||||||
|
config = HierarchicalConfig(enable_kb_routing=False)
|
||||||
|
kb_ids = ["kb1", "kb2", "kb3"]
|
||||||
|
|
||||||
|
result = mock_dealer._tier1_kb_routing("test query", kb_ids, None, config)
|
||||||
|
|
||||||
|
assert result == kb_ids
|
||||||
|
|
||||||
|
def test_routing_no_kb_infos(self, mock_dealer):
|
||||||
|
"""Test that routing returns all KBs when no KB info provided."""
|
||||||
|
config = HierarchicalConfig(enable_kb_routing=True)
|
||||||
|
kb_ids = ["kb1", "kb2", "kb3"]
|
||||||
|
|
||||||
|
result = mock_dealer._tier1_kb_routing("test query", kb_ids, None, config)
|
||||||
|
|
||||||
|
assert result == kb_ids
|
||||||
|
|
||||||
|
def test_routing_few_kbs(self, mock_dealer):
|
||||||
|
"""Test that routing returns all KBs when count <= kb_top_k."""
|
||||||
|
config = HierarchicalConfig(enable_kb_routing=True, kb_top_k=5)
|
||||||
|
kb_ids = ["kb1", "kb2", "kb3"]
|
||||||
|
kb_infos = [
|
||||||
|
{"id": "kb1", "name": "Finance KB", "description": "Financial documents"},
|
||||||
|
{"id": "kb2", "name": "HR KB", "description": "Human resources"},
|
||||||
|
{"id": "kb3", "name": "Tech KB", "description": "Technical docs"},
|
||||||
|
]
|
||||||
|
|
||||||
|
result = mock_dealer._tier1_kb_routing("test query", kb_ids, kb_infos, config)
|
||||||
|
|
||||||
|
assert result == kb_ids
|
||||||
|
|
||||||
|
def test_routing_selects_relevant_kbs(self, mock_dealer):
|
||||||
|
"""Test that routing selects KBs based on keyword overlap."""
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enable_kb_routing=True,
|
||||||
|
kb_top_k=2,
|
||||||
|
kb_routing_threshold=0.1
|
||||||
|
)
|
||||||
|
kb_ids = ["kb1", "kb2", "kb3", "kb4"]
|
||||||
|
kb_infos = [
|
||||||
|
{"id": "kb1", "name": "Finance Reports", "description": "Financial analysis and reports"},
|
||||||
|
{"id": "kb2", "name": "HR Policies", "description": "Human resources policies"},
|
||||||
|
{"id": "kb3", "name": "Technical Documentation", "description": "Engineering docs"},
|
||||||
|
{"id": "kb4", "name": "Financial Statements", "description": "Quarterly financial data"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Query about finance should select finance-related KBs
|
||||||
|
result = mock_dealer._tier1_kb_routing("financial report analysis", kb_ids, kb_infos, config)
|
||||||
|
|
||||||
|
# Should select at most kb_top_k KBs
|
||||||
|
assert len(result) <= config.kb_top_k
|
||||||
|
# Should include finance-related KBs
|
||||||
|
assert any(kb in result for kb in ["kb1", "kb4"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestTier2DocumentFiltering:
|
||||||
|
"""Test Tier 2: Document Filtering logic."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dealer(self):
|
||||||
|
"""Create a mock Dealer instance."""
|
||||||
|
mock_datastore = Mock()
|
||||||
|
dealer = Dealer(mock_datastore)
|
||||||
|
return dealer
|
||||||
|
|
||||||
|
def test_filtering_disabled(self, mock_dealer):
|
||||||
|
"""Test that filtering returns empty when disabled."""
|
||||||
|
config = HierarchicalConfig(enable_doc_filtering=False)
|
||||||
|
|
||||||
|
result = mock_dealer._tier2_document_filtering(
|
||||||
|
"test query", ["tenant1"], ["kb1"], None, config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_filtering_with_existing_doc_ids(self, mock_dealer):
|
||||||
|
"""Test that filtering limits existing doc_ids."""
|
||||||
|
config = HierarchicalConfig(enable_doc_filtering=True, doc_top_k=2)
|
||||||
|
doc_ids = ["doc1", "doc2", "doc3", "doc4"]
|
||||||
|
|
||||||
|
result = mock_dealer._tier2_document_filtering(
|
||||||
|
"test query", ["tenant1"], ["kb1"], doc_ids, config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result == ["doc1", "doc2"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestHierarchicalRetrieval:
|
||||||
|
"""Test the full hierarchical retrieval flow."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dealer(self):
|
||||||
|
"""Create a mock Dealer instance with mocked methods."""
|
||||||
|
mock_datastore = Mock()
|
||||||
|
dealer = Dealer(mock_datastore)
|
||||||
|
|
||||||
|
# Mock the retrieval method
|
||||||
|
dealer.retrieval = Mock(return_value={
|
||||||
|
"total": 5,
|
||||||
|
"chunks": [
|
||||||
|
{"chunk_id": "c1", "content_with_weight": "test content 1"},
|
||||||
|
{"chunk_id": "c2", "content_with_weight": "test content 2"},
|
||||||
|
],
|
||||||
|
"doc_aggs": [],
|
||||||
|
})
|
||||||
|
|
||||||
|
return dealer
|
||||||
|
|
||||||
|
def test_hierarchical_retrieval_basic(self, mock_dealer):
|
||||||
|
"""Test basic hierarchical retrieval flow."""
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enabled=True,
|
||||||
|
enable_kb_routing=False, # Skip routing for this test
|
||||||
|
enable_doc_filtering=False, # Skip filtering for this test
|
||||||
|
)
|
||||||
|
|
||||||
|
result = mock_dealer.hierarchical_retrieval(
|
||||||
|
question="test query",
|
||||||
|
embd_mdl=Mock(),
|
||||||
|
tenant_ids=["tenant1"],
|
||||||
|
kb_ids=["kb1", "kb2"],
|
||||||
|
hierarchical_config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have chunks from retrieval
|
||||||
|
assert "chunks" in result
|
||||||
|
assert len(result["chunks"]) == 2
|
||||||
|
|
||||||
|
# Should have hierarchical metadata
|
||||||
|
assert "hierarchical_metadata" in result
|
||||||
|
metadata = result["hierarchical_metadata"]
|
||||||
|
assert "tier1_time_ms" in metadata
|
||||||
|
assert "tier2_time_ms" in metadata
|
||||||
|
assert "tier3_time_ms" in metadata
|
||||||
|
assert "total_time_ms" in metadata
|
||||||
|
|
||||||
|
def test_hierarchical_retrieval_with_kb_infos(self, mock_dealer):
|
||||||
|
"""Test hierarchical retrieval with KB information."""
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enabled=True,
|
||||||
|
enable_kb_routing=True,
|
||||||
|
kb_top_k=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
kb_infos = [
|
||||||
|
{"id": "kb1", "name": "Finance", "description": "Financial docs"},
|
||||||
|
{"id": "kb2", "name": "HR", "description": "HR policies"},
|
||||||
|
]
|
||||||
|
|
||||||
|
result = mock_dealer.hierarchical_retrieval(
|
||||||
|
question="financial report",
|
||||||
|
embd_mdl=Mock(),
|
||||||
|
tenant_ids=["tenant1"],
|
||||||
|
kb_ids=["kb1", "kb2"],
|
||||||
|
kb_infos=kb_infos,
|
||||||
|
hierarchical_config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "hierarchical_metadata" in result
|
||||||
|
assert "selected_kb_ids" in result["hierarchical_metadata"]
|
||||||
|
|
||||||
|
def test_hierarchical_retrieval_empty_query(self, mock_dealer):
|
||||||
|
"""Test hierarchical retrieval with empty query."""
|
||||||
|
# Mock retrieval to return empty for empty query
|
||||||
|
mock_dealer.retrieval = Mock(return_value={
|
||||||
|
"total": 0,
|
||||||
|
"chunks": [],
|
||||||
|
"doc_aggs": [],
|
||||||
|
})
|
||||||
|
|
||||||
|
config = HierarchicalConfig(enabled=True)
|
||||||
|
|
||||||
|
result = mock_dealer.hierarchical_retrieval(
|
||||||
|
question="",
|
||||||
|
embd_mdl=Mock(),
|
||||||
|
tenant_ids=["tenant1"],
|
||||||
|
kb_ids=["kb1"],
|
||||||
|
hierarchical_config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["total"] == 0
|
||||||
|
assert result["chunks"] == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestIndexName:
|
||||||
|
"""Test index_name utility function."""
|
||||||
|
|
||||||
|
def test_index_name_format(self):
|
||||||
|
"""Test index name formatting."""
|
||||||
|
assert index_name("user123") == "ragflow_user123"
|
||||||
|
assert index_name("tenant_abc") == "ragflow_tenant_abc"
|
||||||
|
|
||||||
|
|
||||||
|
class TestHierarchicalConfigAdvanced:
|
||||||
|
"""Test advanced HierarchicalConfig options."""
|
||||||
|
|
||||||
|
def test_all_config_options(self):
|
||||||
|
"""Test all configuration options."""
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enabled=True,
|
||||||
|
enable_kb_routing=True,
|
||||||
|
kb_routing_method="llm_based",
|
||||||
|
kb_routing_threshold=0.5,
|
||||||
|
kb_top_k=5,
|
||||||
|
enable_doc_filtering=True,
|
||||||
|
doc_top_k=50,
|
||||||
|
metadata_fields=["department", "doc_type"],
|
||||||
|
enable_metadata_similarity=True,
|
||||||
|
metadata_similarity_threshold=0.8,
|
||||||
|
use_llm_metadata_filter=True,
|
||||||
|
chunk_top_k=20,
|
||||||
|
enable_parent_child=True,
|
||||||
|
use_summary_mapping=True,
|
||||||
|
keyword_extraction_prompt="Extract important technical terms",
|
||||||
|
question_generation_prompt="Generate questions about the content",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.kb_routing_method == "llm_based"
|
||||||
|
assert config.metadata_fields == ["department", "doc_type"]
|
||||||
|
assert config.enable_metadata_similarity is True
|
||||||
|
assert config.use_llm_metadata_filter is True
|
||||||
|
assert config.enable_parent_child is True
|
||||||
|
assert config.use_summary_mapping is True
|
||||||
|
assert config.keyword_extraction_prompt is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMKBRouting:
|
||||||
|
"""Test LLM-based KB routing."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dealer(self):
|
||||||
|
"""Create a mock Dealer instance."""
|
||||||
|
mock_datastore = Mock()
|
||||||
|
dealer = Dealer(mock_datastore)
|
||||||
|
return dealer
|
||||||
|
|
||||||
|
def test_llm_routing_fallback_no_model(self, mock_dealer):
|
||||||
|
"""Test LLM routing falls back when no model provided."""
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enable_kb_routing=True,
|
||||||
|
kb_routing_method="llm_based",
|
||||||
|
kb_top_k=2
|
||||||
|
)
|
||||||
|
kb_ids = ["kb1", "kb2", "kb3", "kb4"]
|
||||||
|
kb_infos = [
|
||||||
|
{"id": "kb1", "name": "Finance", "description": "Financial docs"},
|
||||||
|
{"id": "kb2", "name": "HR", "description": "HR policies"},
|
||||||
|
{"id": "kb3", "name": "Tech", "description": "Technical docs"},
|
||||||
|
{"id": "kb4", "name": "Legal", "description": "Legal documents"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Without chat_mdl, should fall back to rule-based
|
||||||
|
result = mock_dealer._tier1_kb_routing(
|
||||||
|
"financial report", kb_ids, kb_infos, config, chat_mdl=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still return results (from rule-based fallback)
|
||||||
|
assert len(result) <= config.kb_top_k
|
||||||
|
|
||||||
|
def test_routing_method_all(self, mock_dealer):
|
||||||
|
"""Test 'all' routing method returns all KBs."""
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enable_kb_routing=True,
|
||||||
|
kb_routing_method="all",
|
||||||
|
kb_top_k=2
|
||||||
|
)
|
||||||
|
kb_ids = ["kb1", "kb2", "kb3", "kb4"]
|
||||||
|
kb_infos = [
|
||||||
|
{"id": "kb1", "name": "Finance", "description": "Financial docs"},
|
||||||
|
{"id": "kb2", "name": "HR", "description": "HR policies"},
|
||||||
|
]
|
||||||
|
|
||||||
|
result = mock_dealer._tier1_kb_routing(
|
||||||
|
"any query", kb_ids, kb_infos, config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == kb_ids
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataSimilarityFilter:
|
||||||
|
"""Test metadata similarity filtering."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dealer(self):
|
||||||
|
"""Create a mock Dealer instance."""
|
||||||
|
mock_datastore = Mock()
|
||||||
|
dealer = Dealer(mock_datastore)
|
||||||
|
return dealer
|
||||||
|
|
||||||
|
def test_similarity_filter_no_model(self, mock_dealer):
|
||||||
|
"""Test similarity filter returns empty without embedding model."""
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enable_metadata_similarity=True,
|
||||||
|
metadata_similarity_threshold=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_metadata = [
|
||||||
|
{"id": "doc1", "name": "Finance Report", "summary": "Q1 financial analysis"},
|
||||||
|
{"id": "doc2", "name": "HR Policy", "summary": "Employee guidelines"},
|
||||||
|
]
|
||||||
|
|
||||||
|
result = mock_dealer._metadata_similarity_filter(
|
||||||
|
"financial analysis", doc_metadata, config, embd_mdl=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_similarity_filter_no_metadata(self, mock_dealer):
|
||||||
|
"""Test similarity filter returns empty without metadata."""
|
||||||
|
config = HierarchicalConfig(enable_metadata_similarity=True)
|
||||||
|
mock_embd = Mock()
|
||||||
|
|
||||||
|
result = mock_dealer._metadata_similarity_filter(
|
||||||
|
"test query", [], config, embd_mdl=mock_embd
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestParentChildRetrieval:
|
||||||
|
"""Test parent-child chunking with summary mapping."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dealer(self):
|
||||||
|
"""Create a mock Dealer instance with mocked methods."""
|
||||||
|
mock_datastore = Mock()
|
||||||
|
dealer = Dealer(mock_datastore)
|
||||||
|
|
||||||
|
# Mock the search method
|
||||||
|
mock_search_result = Mock()
|
||||||
|
mock_search_result.ids = ["chunk1", "chunk2"]
|
||||||
|
mock_search_result.field = {
|
||||||
|
"chunk1": {
|
||||||
|
"content_ltks": "parent content",
|
||||||
|
"content_with_weight": "parent content",
|
||||||
|
"doc_id": "doc1",
|
||||||
|
"docnm_kwd": "test.pdf",
|
||||||
|
"kb_id": "kb1",
|
||||||
|
"mom_id": "",
|
||||||
|
},
|
||||||
|
"chunk2": {
|
||||||
|
"content_ltks": "child content",
|
||||||
|
"content_with_weight": "child content",
|
||||||
|
"doc_id": "doc1",
|
||||||
|
"docnm_kwd": "test.pdf",
|
||||||
|
"kb_id": "kb1",
|
||||||
|
"mom_id": "chunk1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mock_search_result.total = 2
|
||||||
|
mock_search_result.query_vector = [0.1] * 768
|
||||||
|
mock_search_result.highlight = {}
|
||||||
|
|
||||||
|
dealer.search = Mock(return_value=mock_search_result)
|
||||||
|
dealer.rerank = Mock(return_value=([0.9, 0.8], [0.5, 0.4], [0.9, 0.8]))
|
||||||
|
|
||||||
|
return dealer
|
||||||
|
|
||||||
|
def test_summary_mapping_config(self):
|
||||||
|
"""Test summary mapping configuration."""
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enable_parent_child=True,
|
||||||
|
use_summary_mapping=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.enable_parent_child is True
|
||||||
|
assert config.use_summary_mapping is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomKeywordExtraction:
|
||||||
|
"""Test customizable keyword extraction."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dealer(self):
|
||||||
|
"""Create a mock Dealer instance."""
|
||||||
|
mock_datastore = Mock()
|
||||||
|
dealer = Dealer(mock_datastore)
|
||||||
|
return dealer
|
||||||
|
|
||||||
|
def test_keyword_extraction_important(self, mock_dealer):
|
||||||
|
"""Test keyword extraction with 'important' prompt."""
|
||||||
|
chunks = [
|
||||||
|
{
|
||||||
|
"content_with_weight": "The API_KEY and DatabaseConnection are critical components.",
|
||||||
|
"important_kwd": ["existing_term"], # Pre-existing keyword
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
result = mock_dealer._apply_custom_keyword_extraction(
|
||||||
|
chunks, "Extract important technical terms", None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return the chunk (possibly with extracted keywords)
|
||||||
|
assert len(result) == 1
|
||||||
|
# Should preserve existing keywords at minimum
|
||||||
|
assert "existing_term" in result[0].get("important_kwd", [])
|
||||||
|
|
||||||
|
def test_keyword_extraction_question(self, mock_dealer):
|
||||||
|
"""Test keyword extraction with 'question' prompt."""
|
||||||
|
chunks = [
|
||||||
|
{
|
||||||
|
"content_with_weight": "This section explains the authentication flow for the system. Users must provide valid credentials.",
|
||||||
|
"important_kwd": [],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
result = mock_dealer._apply_custom_keyword_extraction(
|
||||||
|
chunks, "Generate questions about the content", None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
# Should have a question hint
|
||||||
|
assert "question_hint" in result[0]
|
||||||
|
|
||||||
|
def test_keyword_extraction_empty_content(self, mock_dealer):
|
||||||
|
"""Test keyword extraction with empty content."""
|
||||||
|
chunks = [{"content_with_weight": "", "important_kwd": []}]
|
||||||
|
|
||||||
|
result = mock_dealer._apply_custom_keyword_extraction(
|
||||||
|
chunks, "Extract important terms", None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestKBRetrievalParams:
|
||||||
|
"""Test per-KB retrieval parameters."""
|
||||||
|
|
||||||
|
def test_default_params(self):
|
||||||
|
"""Test default KB params."""
|
||||||
|
params = KBRetrievalParams(kb_id="kb1")
|
||||||
|
|
||||||
|
assert params.kb_id == "kb1"
|
||||||
|
assert params.vector_similarity_weight == 0.7
|
||||||
|
assert params.similarity_threshold == 0.2
|
||||||
|
assert params.top_k == 1024
|
||||||
|
assert params.rerank_enabled is True
|
||||||
|
|
||||||
|
def test_custom_params(self):
|
||||||
|
"""Test custom KB params."""
|
||||||
|
params = KBRetrievalParams(
|
||||||
|
kb_id="finance_kb",
|
||||||
|
vector_similarity_weight=0.9,
|
||||||
|
similarity_threshold=0.3,
|
||||||
|
top_k=500,
|
||||||
|
rerank_enabled=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert params.kb_id == "finance_kb"
|
||||||
|
assert params.vector_similarity_weight == 0.9
|
||||||
|
assert params.similarity_threshold == 0.3
|
||||||
|
assert params.top_k == 500
|
||||||
|
assert params.rerank_enabled is False
|
||||||
|
|
||||||
|
def test_kb_params_in_config(self):
|
||||||
|
"""Test KB params integration in HierarchicalConfig."""
|
||||||
|
kb_params = {
|
||||||
|
"kb1": KBRetrievalParams(kb_id="kb1", vector_similarity_weight=0.8),
|
||||||
|
"kb2": KBRetrievalParams(kb_id="kb2", similarity_threshold=0.4),
|
||||||
|
}
|
||||||
|
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enabled=True,
|
||||||
|
kb_params=kb_params
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(config.kb_params) == 2
|
||||||
|
assert config.kb_params["kb1"].vector_similarity_weight == 0.8
|
||||||
|
assert config.kb_params["kb2"].similarity_threshold == 0.4
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMQuestionGeneration:
|
||||||
|
"""Test LLM-based question generation."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dealer(self):
|
||||||
|
"""Create a mock Dealer instance."""
|
||||||
|
mock_datastore = Mock()
|
||||||
|
dealer = Dealer(mock_datastore)
|
||||||
|
return dealer
|
||||||
|
|
||||||
|
def test_question_generation_no_model(self, mock_dealer):
|
||||||
|
"""Test question generation returns chunks unchanged without model."""
|
||||||
|
chunks = [{"content_with_weight": "Test content", "important_kwd": []}]
|
||||||
|
|
||||||
|
result = mock_dealer._apply_llm_question_generation(chunks, None, None)
|
||||||
|
|
||||||
|
assert result == chunks
|
||||||
|
assert "generated_questions" not in result[0]
|
||||||
|
|
||||||
|
def test_question_generation_with_mock_model(self, mock_dealer):
|
||||||
|
"""Test question generation with mock LLM."""
|
||||||
|
mock_chat = Mock()
|
||||||
|
mock_chat.chat = Mock(return_value="What is the main topic?\nHow does this work?")
|
||||||
|
|
||||||
|
chunks = [{
|
||||||
|
"content_with_weight": "This is a detailed explanation of the authentication system. It uses OAuth2 for secure access.",
|
||||||
|
"important_kwd": []
|
||||||
|
}]
|
||||||
|
|
||||||
|
result = mock_dealer._apply_llm_question_generation(chunks, None, mock_chat)
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "generated_questions" in result[0]
|
||||||
|
assert len(result[0]["generated_questions"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestDocumentMetadataGeneration:
|
||||||
|
"""Test document metadata generation."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dealer(self):
|
||||||
|
"""Create a mock Dealer instance."""
|
||||||
|
mock_datastore = Mock()
|
||||||
|
dealer = Dealer(mock_datastore)
|
||||||
|
return dealer
|
||||||
|
|
||||||
|
def test_metadata_generation_no_model(self, mock_dealer):
|
||||||
|
"""Test metadata generation returns empty without model."""
|
||||||
|
result = mock_dealer.generate_document_metadata("doc1", "content", None, None)
|
||||||
|
|
||||||
|
assert result["doc_id"] == "doc1"
|
||||||
|
assert result["generated"] is True
|
||||||
|
assert result["summary"] == ""
|
||||||
|
assert result["topics"] == []
|
||||||
|
|
||||||
|
def test_metadata_generation_with_mock_model(self, mock_dealer):
|
||||||
|
"""Test metadata generation with mock LLM."""
|
||||||
|
mock_chat = Mock()
|
||||||
|
mock_chat.chat = Mock(return_value="""SUMMARY: This document explains authentication.
|
||||||
|
TOPICS: security, OAuth, authentication
|
||||||
|
QUESTIONS:
|
||||||
|
- What is OAuth?
|
||||||
|
- How does authentication work?
|
||||||
|
CATEGORY: Technical Documentation""")
|
||||||
|
|
||||||
|
result = mock_dealer.generate_document_metadata(
|
||||||
|
"doc1",
|
||||||
|
"This is content about authentication and security.",
|
||||||
|
mock_chat,
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["doc_id"] == "doc1"
|
||||||
|
assert "authentication" in result["summary"]
|
||||||
|
assert len(result["topics"]) == 3
|
||||||
|
assert "security" in result["topics"]
|
||||||
|
assert len(result["suggested_questions"]) == 2
|
||||||
|
assert result["category"] == "Technical Documentation"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFullHierarchicalConfig:
|
||||||
|
"""Test complete hierarchical configuration."""
|
||||||
|
|
||||||
|
def test_all_features_enabled(self):
|
||||||
|
"""Test config with all features enabled."""
|
||||||
|
config = HierarchicalConfig(
|
||||||
|
enabled=True,
|
||||||
|
# Tier 1
|
||||||
|
enable_kb_routing=True,
|
||||||
|
kb_routing_method="llm_based",
|
||||||
|
kb_routing_threshold=0.4,
|
||||||
|
kb_top_k=5,
|
||||||
|
kb_params={
|
||||||
|
"kb1": KBRetrievalParams(kb_id="kb1", vector_similarity_weight=0.9)
|
||||||
|
},
|
||||||
|
# Tier 2
|
||||||
|
enable_doc_filtering=True,
|
||||||
|
doc_top_k=50,
|
||||||
|
metadata_fields=["department", "author"],
|
||||||
|
enable_metadata_similarity=True,
|
||||||
|
metadata_similarity_threshold=0.8,
|
||||||
|
use_llm_metadata_filter=True,
|
||||||
|
# Tier 3
|
||||||
|
chunk_top_k=20,
|
||||||
|
enable_parent_child=True,
|
||||||
|
use_summary_mapping=True,
|
||||||
|
# Prompts
|
||||||
|
keyword_extraction_prompt="Extract domain-specific terms",
|
||||||
|
question_generation_prompt="Generate FAQ questions",
|
||||||
|
use_llm_question_generation=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.enabled is True
|
||||||
|
assert config.kb_routing_method == "llm_based"
|
||||||
|
assert len(config.kb_params) == 1
|
||||||
|
assert config.enable_metadata_similarity is True
|
||||||
|
assert config.use_llm_metadata_filter is True
|
||||||
|
assert config.enable_parent_child is True
|
||||||
|
assert config.use_summary_mapping is True
|
||||||
|
assert config.use_llm_question_generation is True
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
230
test/unit_test/services/test_metadata_service.py
Normal file
230
test/unit_test/services/test_metadata_service.py
Normal file
|
|
@ -0,0 +1,230 @@
|
||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit tests for MetadataService.
|
||||||
|
|
||||||
|
Tests batch CRUD operations for document metadata management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataServiceBatchGet:
|
||||||
|
"""Test batch_get_metadata functionality."""
|
||||||
|
|
||||||
|
def test_batch_get_empty_ids(self):
|
||||||
|
"""Test batch get with empty doc_ids returns empty dict."""
|
||||||
|
from api.db.services.metadata_service import MetadataService
|
||||||
|
|
||||||
|
with patch.object(MetadataService, 'batch_get_metadata', return_value={}):
|
||||||
|
result = MetadataService.batch_get_metadata([])
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_batch_get_with_fields_filter(self):
|
||||||
|
"""Test batch get filters to requested fields."""
|
||||||
|
# This tests the logic of field filtering
|
||||||
|
full_metadata = {"field1": "value1", "field2": "value2", "field3": "value3"}
|
||||||
|
requested_fields = ["field1", "field3"]
|
||||||
|
|
||||||
|
filtered = {k: v for k, v in full_metadata.items() if k in requested_fields}
|
||||||
|
|
||||||
|
assert "field1" in filtered
|
||||||
|
assert "field3" in filtered
|
||||||
|
assert "field2" not in filtered
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataServiceBatchUpdate:
|
||||||
|
"""Test batch_update_metadata functionality."""
|
||||||
|
|
||||||
|
def test_update_merge_logic(self):
|
||||||
|
"""Test metadata merge logic."""
|
||||||
|
existing = {"field1": "old_value", "field2": "keep_this"}
|
||||||
|
new_metadata = {"field1": "new_value", "field3": "added"}
|
||||||
|
|
||||||
|
# Merge logic
|
||||||
|
existing.update(new_metadata)
|
||||||
|
|
||||||
|
assert existing["field1"] == "new_value" # Updated
|
||||||
|
assert existing["field2"] == "keep_this" # Preserved
|
||||||
|
assert existing["field3"] == "added" # Added
|
||||||
|
|
||||||
|
def test_update_replace_logic(self):
|
||||||
|
"""Test metadata replace logic."""
|
||||||
|
existing = {"field1": "old_value", "field2": "keep_this"}
|
||||||
|
new_metadata = {"field1": "new_value", "field3": "added"}
|
||||||
|
|
||||||
|
# Replace logic (don't merge)
|
||||||
|
result = new_metadata
|
||||||
|
|
||||||
|
assert result["field1"] == "new_value"
|
||||||
|
assert "field2" not in result # Not preserved
|
||||||
|
assert result["field3"] == "added"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataServiceBatchDelete:
|
||||||
|
"""Test batch_delete_metadata_fields functionality."""
|
||||||
|
|
||||||
|
def test_delete_fields_logic(self):
|
||||||
|
"""Test field deletion logic."""
|
||||||
|
metadata = {"field1": "value1", "field2": "value2", "field3": "value3"}
|
||||||
|
fields_to_delete = ["field1", "field3"]
|
||||||
|
|
||||||
|
for field in fields_to_delete:
|
||||||
|
if field in metadata:
|
||||||
|
del metadata[field]
|
||||||
|
|
||||||
|
assert "field1" not in metadata
|
||||||
|
assert "field2" in metadata
|
||||||
|
assert "field3" not in metadata
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataServiceSearch:
|
||||||
|
"""Test search_by_metadata functionality."""
|
||||||
|
|
||||||
|
def test_equals_filter(self):
|
||||||
|
"""Test equals filter logic."""
|
||||||
|
doc_value = "Technical"
|
||||||
|
condition = "Technical"
|
||||||
|
|
||||||
|
matches = str(doc_value) == str(condition)
|
||||||
|
assert matches is True
|
||||||
|
|
||||||
|
def test_contains_filter(self):
|
||||||
|
"""Test contains filter logic."""
|
||||||
|
doc_value = "Technical Documentation"
|
||||||
|
condition = {"contains": "Technical"}
|
||||||
|
|
||||||
|
val = condition["contains"]
|
||||||
|
matches = str(val).lower() in str(doc_value).lower()
|
||||||
|
assert matches is True
|
||||||
|
|
||||||
|
def test_starts_with_filter(self):
|
||||||
|
"""Test starts_with filter logic."""
|
||||||
|
doc_value = "Technical Documentation"
|
||||||
|
condition = {"starts_with": "Tech"}
|
||||||
|
|
||||||
|
val = condition["starts_with"]
|
||||||
|
matches = str(doc_value).lower().startswith(str(val).lower())
|
||||||
|
assert matches is True
|
||||||
|
|
||||||
|
def test_gt_filter(self):
|
||||||
|
"""Test greater than filter logic."""
|
||||||
|
doc_value = 2023
|
||||||
|
condition = {"gt": 2020}
|
||||||
|
|
||||||
|
val = condition["gt"]
|
||||||
|
matches = float(doc_value) > float(val)
|
||||||
|
assert matches is True
|
||||||
|
|
||||||
|
def test_lt_filter(self):
|
||||||
|
"""Test less than filter logic."""
|
||||||
|
doc_value = 2019
|
||||||
|
condition = {"lt": 2020}
|
||||||
|
|
||||||
|
val = condition["lt"]
|
||||||
|
matches = float(doc_value) < float(val)
|
||||||
|
assert matches is True
|
||||||
|
|
||||||
|
def test_in_filter(self):
|
||||||
|
"""Test in filter logic."""
|
||||||
|
doc_value = "Technical"
|
||||||
|
condition = {"in": ["Technical", "Legal", "HR"]}
|
||||||
|
|
||||||
|
val = condition["in"]
|
||||||
|
matches = doc_value in val
|
||||||
|
assert matches is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataServiceSchema:
    """Cover the introspection rules used by get_metadata_schema."""

    def test_schema_type_detection(self):
        """Schema type names come straight from the value's Python type."""
        for sample, expected in (
            ("string_value", "str"),
            (123, "int"),
            (12.5, "float"),
            (True, "bool"),
            (["a", "b"], "list"),
        ):
            assert type(sample).__name__ == expected

    def test_schema_sample_values_limit(self):
        """Collected sample values are capped at the configured maximum."""
        cap = 10
        collected = set()
        for idx in range(20):
            if len(collected) >= cap:
                break  # cap reached; further values are ignored
            collected.add(f"value_{idx}")
        assert len(collected) == cap
||||||
|
class TestMetadataServiceStatistics:
    """Cover the coverage-ratio arithmetic in get_metadata_statistics."""

    def test_coverage_calculation(self):
        """Coverage is the fraction of documents that carry any metadata."""
        total = 100
        tagged = 80
        ratio = tagged / total if total else 0
        assert ratio == 0.8

    def test_coverage_zero_docs(self):
        """An empty knowledge base reports 0 coverage, not a ZeroDivisionError."""
        total = 0
        tagged = 0
        ratio = tagged / total if total else 0
        assert ratio == 0
||||||
|
class TestMetadataServiceCopy:
    """Cover the field-selection semantics of copy_metadata."""

    def test_copy_all_fields(self):
        """With no field list, every entry is copied into a new dict object."""
        source = {"field1": "value1", "field2": "value2"}
        duplicate = dict(source)
        assert duplicate == source
        # Must be a distinct object so mutating the copy can't touch the source.
        assert duplicate is not source

    def test_copy_specific_fields(self):
        """With an explicit field list, only the named keys are carried over."""
        source = {"field1": "value1", "field2": "value2", "field3": "value3"}
        wanted = ["field1", "field3"]
        subset = {key: value for key, value in source.items() if key in wanted}
        assert "field1" in subset
        assert "field2" not in subset
        assert "field3" in subset
||||||
|
# Allow running this test module directly (python this_file.py) by delegating
# to pytest's CLI with verbose output for just this file.
# NOTE(review): relies on `pytest` being imported at the top of the module
# (outside this view) — confirm the import exists.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||||
Loading…
Add table
Reference in a new issue