#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Query Decomposition Retrieval Component

This module implements an advanced retrieval system that automatically decomposes complex queries
into simpler sub-queries, performs concurrent retrieval, and intelligently reranks results using
LLM-based scoring combined with vector similarity.

Key Features:
- Automatic query decomposition for complex, multi-faceted questions
- Concurrent retrieval across multiple sub-queries for better performance
- Global chunk deduplication to eliminate redundant results
- LLM-based relevance scoring for each chunk
- Configurable score fusion between vector similarity and LLM scores
- Built-in high-quality default prompts

Use Cases:
- Complex queries with multiple aspects ("Compare A and B", "Explain X, Y, and Z")
- Multi-hop reasoning questions
- Research queries requiring comprehensive coverage
- Questions needing information from multiple sources

Advantages over Manual Workflow Approach:
- Simplified configuration (no need to manually assemble components)
- Better performance (internal concurrency, minimal overhead)
- Higher quality results (global deduplication and reranking)
- Deterministic behavior (no agent unpredictability)
"""

import asyncio
import json
import logging
import os
import re
from abc import ABC
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple

import numpy as np

from agent.tools.base import ToolBase, ToolMeta, ToolParamBase
from api.db.services.document_service import DocumentService
from api.db.services.dialog_service import meta_filter
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from common import settings
from common.connection_utils import timeout
from common.constants import LLMType
from rag.app.tag import label_question
from rag.prompts.generator import cross_languages, gen_meta_filter, kb_prompt


# -----------------------------------------------------------------------------
# Default Prompts
# -----------------------------------------------------------------------------

# Default prompt for query decomposition
# This prompt instructs the LLM to break down complex queries into simpler sub-questions
DEFAULT_DECOMPOSITION_PROMPT = """You are a query decomposition expert. Your task is to break down complex questions into at most {max_count} simpler, independently retrievable sub-questions.

Guidelines:
1. Each sub-question should focus on one specific aspect of the original query
2. Sub-questions should be non-redundant (no overlap in information sought)
3. Together, the sub-questions should cover all key aspects of the original query
4. Each sub-question should be clear, specific, and independently answerable
5. Avoid creating more than {max_count} sub-questions

**Original Query:** {original_query}

**Output Requirement:** Output ONLY a standard JSON array where each element is a string representing a sub-question.
Example format: ["What is concept A?", "What is concept B?", "How do A and B compare?"]

Do not output any explanatory text, commentary, or formatting other than the JSON array.

Sub-questions:"""
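
# Illustrative example of the intended decomposition behavior (assumed, not taken from a
# real model run): the "Compare A and B" use case from the module docstring might yield
# ["What is A?", "What is B?", "How do A and B compare?"].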

# Default prompt for LLM-based chunk relevance scoring
# This prompt instructs the LLM to judge how useful a chunk is for answering a query
DEFAULT_RERANKING_PROMPT = """You are an information relevance assessment expert. Your task is to judge how useful a given document chunk is for answering a specific query.

**Query:** {query}

**Document Chunk:**
{chunk_text}

**Assessment Task:**
1. **Relevance Score**: Provide an integer score from 1 to 10 based on the following criteria:
   - 9-10: Contains direct, complete answer to the query
   - 7-8: Contains substantial relevant information that partially answers the query
   - 5-6: Contains indirect clues or related context that could help answer the query
   - 3-4: Tangentially related but not directly useful
   - 1-2: Completely irrelevant to the query

2. **Brief Justification**: In ONE sentence, explain the core reason for your score.

**Output Requirement:** Output STRICTLY in JSON format with no additional text:
{{"score": <integer 1-10>, "reason": "<one sentence justification>"}}

Assessment:"""
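
# Example of the JSON this prompt asks the model to return (illustrative values only):
#   {"score": 8, "reason": "The chunk directly addresses the main subject of the query."}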


# -----------------------------------------------------------------------------
# Parameter Configuration Class
# -----------------------------------------------------------------------------

class QueryDecompositionRetrievalParam(ToolParamBase):
    """
    Configuration parameters for the Query Decomposition Retrieval component.

    This class defines all configurable parameters for the advanced retrieval system,
    including query decomposition settings, reranking configuration, and score fusion weights.

    Attributes:
        enable_decomposition (bool): Master toggle for the entire decomposition pipeline
        decomposition_prompt (str): Custom or default prompt for query decomposition
        reranking_prompt (str): Custom or default prompt for LLM-based chunk scoring
        score_fusion_weight (float): Weight of the LLM score vs. vector similarity (0.0-1.0)
        max_decomposition_count (int): Maximum number of sub-queries to generate
        enable_concurrency (bool): Whether to retrieve sub-queries concurrently
        similarity_threshold (float): Minimum similarity score for chunk inclusion
        keywords_similarity_weight (float): Weight of keyword matching vs. vector similarity
        top_n (int): Number of final chunks to return
        top_k (int): Number of initial candidates to retrieve per sub-query
        kb_ids (list): Knowledge base IDs to search
        rerank_id (str): Reranking model ID (for traditional reranking if needed)
        empty_response (str): Response when no results are found
        use_kg (bool): Whether to use knowledge graph retrieval
        cross_languages (list): Languages for cross-lingual retrieval
        toc_enhance (bool): Whether to enhance with table-of-contents
        meta_data_filter (dict): Metadata filters for retrieval
    """

    def __init__(self):
        """Initialize query decomposition retrieval parameters with sensible defaults."""
        # Define tool metadata for agent integration
        self.meta: ToolMeta = {
            "name": "advanced_search_with_decomposition",
            "description": (
                "Advanced retrieval tool that automatically decomposes complex queries into "
                "simpler sub-questions, performs concurrent retrieval, and intelligently "
                "reranks results using LLM-based scoring."
            ),
            "parameters": {
                "query": {
                    "type": "string",
                    "description": (
                        "The complex query to search for. Can be multi-faceted or require "
                        "information from multiple sources."
                    ),
                    "default": "",
                    "required": True
                }
            }
        }

        super().__init__()

        # Tool identification
        self.function_name = "advanced_search_with_decomposition"
        self.description = self.meta["description"]

        # Query Decomposition Settings
        # These control how complex queries are broken down into sub-questions
        self.enable_decomposition = True  # Master toggle for decomposition feature
        self.decomposition_prompt = DEFAULT_DECOMPOSITION_PROMPT  # LLM prompt for decomposition
        self.max_decomposition_count = 3  # Limit sub-queries to prevent over-decomposition

        # Reranking & Scoring Settings
        # These control how chunks are scored and ranked after retrieval
        self.reranking_prompt = DEFAULT_RERANKING_PROMPT  # LLM prompt for chunk scoring
        self.score_fusion_weight = 0.7  # Weight: 0.7*LLM_score + 0.3*vector_score

        # Concurrency Settings
        # Controls whether sub-queries are processed in parallel
        self.enable_concurrency = True  # Enable parallel retrieval for better performance

        # Traditional Retrieval Settings
        # These are inherited from the base Retrieval component
        self.similarity_threshold = 0.2  # Minimum similarity score to include a chunk
        self.keywords_similarity_weight = 0.3  # Weight for keyword vs vector matching
        self.top_n = 8  # Number of final results to return
        self.top_k = 1024  # Number of initial candidates to retrieve

        # Knowledge Base Settings
        self.kb_ids = []  # List of knowledge base IDs to search
        self.kb_vars = []  # Knowledge base variables for dynamic selection
        self.rerank_id = ""  # Traditional reranking model ID
        self.empty_response = "No relevant information found."  # Default empty response

        # Advanced Features
        self.use_kg = False  # Whether to use knowledge graph retrieval
        self.cross_languages = []  # Languages for cross-lingual search
        self.toc_enhance = False  # Whether to enhance with document TOC
        self.meta_data_filter = {}  # Metadata filtering criteria

    def check(self):
        """
        Validate parameter values to ensure they are within acceptable ranges.

        This method is called before execution to catch configuration errors early.
        It checks that all numerical parameters are within valid ranges and that
        required parameters are properly set.
        """
        # Validate similarity thresholds (must be between 0.0 and 1.0)
        self.check_decimal_float(
            self.similarity_threshold,
            "[QueryDecompositionRetrieval] Similarity threshold"
        )
        self.check_decimal_float(
            self.keywords_similarity_weight,
            "[QueryDecompositionRetrieval] Keyword similarity weight"
        )
        self.check_decimal_float(
            self.score_fusion_weight,
            "[QueryDecompositionRetrieval] Score fusion weight"
        )

        # Validate positive integers
        self.check_positive_number(
            self.top_n,
            "[QueryDecompositionRetrieval] Top N"
        )
        self.check_positive_number(
            self.max_decomposition_count,
            "[QueryDecompositionRetrieval] Max decomposition count"
        )

        # Ensure max_decomposition_count is reasonable (1-10)
        if not (1 <= self.max_decomposition_count <= 10):
            raise ValueError(
                f"Max decomposition count must be between 1 and 10, got {self.max_decomposition_count}"
            )

        # Ensure score_fusion_weight is between 0 and 1
        if not (0.0 <= self.score_fusion_weight <= 1.0):
            raise ValueError(
                f"Score fusion weight must be between 0.0 and 1.0, got {self.score_fusion_weight}"
            )

    def get_input_form(self) -> Dict[str, Dict]:
        """
        Define the input form structure for UI rendering.

        Returns:
            Dictionary defining the input fields and their types for the UI
        """
        return {
            "query": {
                "name": "Query",
                "type": "line"
            }
        }
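

# Configuration sketch (illustrative; the component itself is constructed and wired by the
# agent canvas like other ToolBase tools, so only the parameter object is shown, and the
# knowledge base ID below is a placeholder):
#
#     param = QueryDecompositionRetrievalParam()
#     param.kb_ids = ["<your_kb_id>"]        # knowledge bases to search
#     param.max_decomposition_count = 3      # cap on generated sub-queries
#     param.score_fusion_weight = 0.7        # 0.7 * LLM score + 0.3 * vector similarity
#     param.check()                          # validate ranges before execution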


# -----------------------------------------------------------------------------
# Query Decomposition Retrieval Component
# -----------------------------------------------------------------------------

class QueryDecompositionRetrieval(ToolBase, ABC):
    """
    Advanced retrieval component with automatic query decomposition and intelligent reranking.

    This component implements a sophisticated retrieval pipeline that:
    1. Analyzes the input query for complexity
    2. Decomposes complex queries into simpler sub-questions using LLM
    3. Performs concurrent retrieval for all sub-queries
    4. Deduplicates chunks across all sub-query results
    5. Uses LLM to score each unique chunk's relevance
    6. Fuses LLM scores with vector similarity scores
    7. Returns globally ranked, deduplicated results

    This approach delivers better results than manual workflow assembly or agent-based
    approaches while maintaining high performance and deterministic behavior.
    """

    component_name = "QueryDecompositionRetrieval"

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))  # Longer timeout for complex processing
    def _invoke(self, **kwargs):
        """
        Main execution method for the query decomposition retrieval component.

        This method orchestrates the entire retrieval pipeline from query decomposition
        through final result ranking.

        Args:
            **kwargs: Keyword arguments including:
                - query (str): The user's input query

        Returns:
            None. Results are exposed through the output variables listed below.

        Side Effects:
            - Sets output variables: "formalized_content" and "json"
            - Adds references to the canvas for citation tracking
        """
        # Check if component execution was canceled by user
        if self.check_if_canceled("Query decomposition retrieval processing"):
            return

        # Extract and validate query
        query = kwargs.get("query", "").strip()
        if not query:
            # No query provided - return empty response
            logging.warning("No query provided to QueryDecompositionRetrieval")
            self.set_output("formalized_content", self._param.empty_response)
            self.set_output("json", [])
            return

        try:
            # Step 1: Resolve and validate knowledge bases
            kb_ids, kbs, embd_mdl, rerank_mdl = self._prepare_knowledge_bases()

            if not kbs:
                raise Exception("No valid knowledge bases found")

            # Step 2: Process query variables and format the query string
            query = self._process_query_variables(query)

            # Step 3: Determine whether to use decomposition
            # If decomposition is disabled or query is simple, use direct retrieval
            if not self._param.enable_decomposition:
                logging.info("Query decomposition is disabled - using direct retrieval")
                results = self._direct_retrieval(
                    query=query,
                    kb_ids=kb_ids,
                    kbs=kbs,
                    embd_mdl=embd_mdl,
                    rerank_mdl=rerank_mdl
                )
            else:
                # Step 4: Decompose query into sub-questions
                sub_queries = self._decompose_query(query)

                # If decomposition failed or returned only one query, fall back to direct retrieval
                if len(sub_queries) <= 1:
                    logging.info(f"Query decomposition produced {len(sub_queries)} sub-queries - using direct retrieval")
                    results = self._direct_retrieval(
                        query=query,
                        kb_ids=kb_ids,
                        kbs=kbs,
                        embd_mdl=embd_mdl,
                        rerank_mdl=rerank_mdl
                    )
                else:
                    logging.info(f"Query decomposed into {len(sub_queries)} sub-queries: {sub_queries}")

                    # Step 5: Retrieve chunks for all sub-queries
                    all_chunks = self._concurrent_retrieval(
                        sub_queries=sub_queries,
                        kb_ids=kb_ids,
                        kbs=kbs,
                        embd_mdl=embd_mdl,
                        rerank_mdl=rerank_mdl
                    )

                    # Step 6: Deduplicate and globally rerank
                    results = self._global_rerank_and_deduplicate(
                        chunks_by_query=all_chunks,
                        sub_queries=sub_queries,
                        original_query=query
                    )

            # Step 7: Format and return results
            self._format_and_set_output(results)

            logging.info(f"Query decomposition retrieval completed successfully with {len(results)} results")

        except Exception as e:
            # Log the error and return empty response
            logging.exception(f"Error in QueryDecompositionRetrieval: {str(e)}")
            self.set_output("formalized_content", self._param.empty_response)
            self.set_output("json", [])
            raise

    def _prepare_knowledge_bases(self) -> Tuple:
        """
        Resolve knowledge base IDs and prepare embedding/reranking models.

        This method:
        1. Resolves knowledge base IDs (including variable references)
        2. Validates that all KBs exist and are accessible
        3. Ensures all KBs use the same embedding model
        4. Initializes embedding and reranking model bundles

        Returns:
            Tuple of (kb_ids, kbs, embd_mdl, rerank_mdl)

        Raises:
            Exception: If no valid knowledge bases found or if KBs use different embeddings
        """
        logging.info("Preparing knowledge bases for retrieval")

        # Resolve knowledge base IDs (may include variable references like "@kb_var")
        kb_ids: List[str] = []
        for kb_id in self._param.kb_ids:
            if kb_id.find("@") < 0:
                # Direct KB ID reference
                kb_ids.append(kb_id)
                continue

            # Variable reference - resolve the actual KB name/ID
            kb_nm = self._canvas.get_variable_value(kb_id)
            kb_nm_list = kb_nm if isinstance(kb_nm, list) else [kb_nm]

            for nm_or_id in kb_nm_list:
                # Try to find KB by name first, then by ID
                e, kb = KnowledgebaseService.get_by_name(
                    nm_or_id,
                    self._canvas._tenant_id
                )
                if not e:
                    e, kb = KnowledgebaseService.get_by_id(nm_or_id)
                if not e:
                    raise Exception(f"Knowledge base ({nm_or_id}) does not exist")
                kb_ids.append(kb.id)

        # Remove duplicates and filter empty IDs
        kb_ids = list(set([kb_id for kb_id in kb_ids if kb_id]))

        if not kb_ids:
            raise Exception("No knowledge base IDs provided")

        # Retrieve knowledge base objects
        kbs = KnowledgebaseService.get_by_ids(kb_ids)
        if not kbs:
            raise Exception("No valid knowledge bases found")

        # Verify all KBs use the same embedding model
        embd_nms = list(set([kb.embd_id for kb in kbs]))
        if len(embd_nms) > 1:
            raise Exception(
                f"Knowledge bases use different embedding models: {embd_nms}. "
                "All KBs must use the same embedding model for consistent retrieval."
            )

        # Initialize embedding model bundle
        embd_mdl = None
        if embd_nms and embd_nms[0]:
            embd_mdl = LLMBundle(
                self._canvas.get_tenant_id(),
                LLMType.EMBEDDING,
                embd_nms[0]
            )

        # Initialize reranking model bundle if specified
        rerank_mdl = None
        if self._param.rerank_id:
            rerank_mdl = LLMBundle(
                kbs[0].tenant_id,
                LLMType.RERANK,
                self._param.rerank_id
            )

        logging.info(f"Prepared {len(kbs)} knowledge bases with embedding model: {embd_nms[0] if embd_nms else 'None'}")

        return kb_ids, kbs, embd_mdl, rerank_mdl

    def _process_query_variables(self, query: str) -> str:
        """
        Process and substitute variables in the query string.

        Queries may contain variable references (e.g., "{user_name}") that need to be
        replaced with actual values from the canvas context.

        Args:
            query (str): Query string potentially containing variable references

        Returns:
            str: Query with all variables substituted
        """
        # Extract variable references from query text
        variables = self.get_input_elements_from_text(query)
        variables = {k: o["value"] for k, o in variables.items()}

        # Substitute variables into query
        query = self.string_format(query, variables)

        return query.strip()

    def _decompose_query(self, query: str) -> List[str]:
        """
        Decompose a complex query into simpler sub-questions using LLM.

        This method calls the LLM with the decomposition prompt to break down
        a complex query into up to `max_decomposition_count` simpler, independently
        answerable sub-questions.

        Args:
            query (str): The original complex query

        Returns:
            List[str]: List of sub-questions (returns [query] if decomposition fails)
        """
        logging.info(f"Decomposing query: {query}")

        try:
            # Get LLM for query decomposition
            llm = LLMBundle(
                self._canvas.get_tenant_id(),
                LLMType.CHAT
            )

            # Format the decomposition prompt with the actual query
            prompt = self._param.decomposition_prompt.format(
                original_query=query,
                max_count=self._param.max_decomposition_count
            )

            # Call LLM to decompose query
            # We use a system message to set context and a user message with the prompt
            response = llm.chat(
                system="You are a helpful assistant that decomposes complex queries into simpler sub-questions.",
                messages=[{"role": "user", "content": prompt}],
                gen_conf={"temperature": 0.1, "max_tokens": 500}  # Low temperature for consistency
            )

            # Extract the response text
            response_text = response.strip()

            logging.debug(f"LLM decomposition response: {response_text}")

            # Parse JSON array from response
            # The LLM should return a JSON array like: ["question 1", "question 2"]
            sub_queries = self._parse_sub_queries_from_response(response_text)

            # Validate sub-queries
            if not sub_queries:
                logging.warning("Query decomposition produced no sub-queries - falling back to original query")
                return [query]

            # Limit to max_decomposition_count
            sub_queries = sub_queries[:self._param.max_decomposition_count]

            logging.info(f"Successfully decomposed query into {len(sub_queries)} sub-queries")
            return sub_queries

        except Exception as e:
            # If decomposition fails for any reason, fall back to using the original query
            logging.exception(f"Query decomposition failed: {str(e)}")
            logging.info("Falling back to original query")
            return [query]

    def _parse_sub_queries_from_response(self, response_text: str) -> List[str]:
        """
        Parse sub-queries from LLM response text.

        The LLM should return a JSON array, but may include extra text or formatting.
        This method extracts the JSON array and parses it robustly.

        Args:
            response_text (str): Raw text response from LLM

        Returns:
            List[str]: Parsed list of sub-questions
        """
        # Try to find JSON array in response
        # Look for patterns like: ["query1", "query2"] or ['query1', 'query2']

        # First, try direct JSON parsing
        try:
            sub_queries = json.loads(response_text)
            if isinstance(sub_queries, list) and all(isinstance(q, str) for q in sub_queries):
                return [q.strip() for q in sub_queries if q.strip()]
        except json.JSONDecodeError:
            pass

        # Try to extract JSON array from response using regex
        json_array_pattern = r'\[(?:[^\[\]]|"(?:[^"\\]|\\.)*")+\]'
        matches = re.findall(json_array_pattern, response_text, re.DOTALL)

        for match in matches:
            try:
                sub_queries = json.loads(match)
                if isinstance(sub_queries, list) and all(isinstance(q, str) for q in sub_queries):
                    return [q.strip() for q in sub_queries if q.strip()]
            except json.JSONDecodeError:
                continue

        # If JSON parsing fails, try to extract quoted strings
        # Look for strings in quotes: "question 1", "question 2"
        quoted_pattern = r'"([^"]+)"|\'([^\']+)\''
        matches = re.findall(quoted_pattern, response_text)
        if matches:
            sub_queries = [m[0] or m[1] for m in matches]
            sub_queries = [q.strip() for q in sub_queries if q.strip()]
            if sub_queries:
                return sub_queries

        # If all parsing attempts fail, return empty list
        logging.warning(f"Failed to parse sub-queries from LLM response: {response_text}")
        return []

    def _direct_retrieval(
        self,
        query: str,
        kb_ids: List[str],
        kbs: List,
        embd_mdl,
        rerank_mdl
    ) -> List[Dict]:
        """
        Perform direct retrieval without query decomposition.

        This is a fallback method used when:
        - Query decomposition is disabled
        - Query decomposition produces only one sub-query
        - Query decomposition fails

        Args:
            query (str): Query to search for
            kb_ids (List[str]): Knowledge base IDs
            kbs (List): Knowledge base objects
            embd_mdl: Embedding model bundle
            rerank_mdl: Reranking model bundle

        Returns:
            List[Dict]: Retrieved and ranked chunks
        """
        logging.info(f"Performing direct retrieval for query: {query}")

        # Prepare document IDs if metadata filtering is enabled
        doc_ids = []
        if self._param.meta_data_filter:
            doc_ids = meta_filter(
                self._param.meta_data_filter,
                kb_ids
            )
            if not doc_ids:
                logging.warning("Metadata filtering returned no matching documents")
                return []

        # Handle cross-language retrieval if enabled
        # This translates the query into multiple languages for broader search
        if self._param.cross_languages:
            trans_queries = {}
            for lang in self._param.cross_languages:
                trans_q, _ = cross_languages(
                    [query],
                    [lang],
                    LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
                )
                trans_queries[lang] = trans_q

        # Perform vector retrieval using the standard retrieval engine
        kbinfos = settings.retriever.retrieval(
            question=query,
            embd_mdl=embd_mdl,
            tenant_ids=[kb.tenant_id for kb in kbs],
            kb_ids=kb_ids,
            page=1,
            page_size=self._param.top_n,
            similarity_threshold=self._param.similarity_threshold,
            vector_similarity_weight=1 - self._param.keywords_similarity_weight,
            top=self._param.top_k,
            doc_ids=doc_ids if doc_ids else None,
            aggs=True,
            rerank_mdl=rerank_mdl
        )

        # Check if retrieval was canceled
        if self.check_if_canceled("Direct retrieval processing"):
            return []

        # Handle knowledge graph retrieval if enabled
        if self._param.use_kg and kbs:
            ck = settings.kg_retriever.retrieval(
                query,
                [kb.tenant_id for kb in kbs],
                kb_ids,
                embd_mdl,
                LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
            )
            if self.check_if_canceled("KG retrieval processing"):
                return []
            if ck.get("content_with_weight"):
                kbinfos["chunks"].insert(0, ck)

        return kbinfos.get("chunks", [])

    def _concurrent_retrieval(
        self,
        sub_queries: List[str],
        kb_ids: List[str],
        kbs: List,
        embd_mdl,
        rerank_mdl
    ) -> Dict[str, List[Dict]]:
        """
        Perform concurrent retrieval for multiple sub-queries.

        This method retrieves chunks for all sub-queries either sequentially or
        in parallel (depending on the enable_concurrency setting). Concurrent execution
        significantly improves performance for multiple sub-queries.

        Args:
            sub_queries (List[str]): List of sub-questions to retrieve for
            kb_ids (List[str]): Knowledge base IDs
            kbs (List): Knowledge base objects
            embd_mdl: Embedding model bundle
            rerank_mdl: Reranking model bundle

        Returns:
            Dict[str, List[Dict]]: Mapping of sub-query -> retrieved chunks
        """
        logging.info(f"Performing concurrent retrieval for {len(sub_queries)} sub-queries")

        chunks_by_query = {}

        if self._param.enable_concurrency and len(sub_queries) > 1:
            # Concurrent execution for better performance
            with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
                # Submit all retrieval tasks
                future_to_query = {
                    executor.submit(
                        self._retrieve_for_sub_query,
                        sub_query,
                        kb_ids,
                        kbs,
                        embd_mdl,
                        rerank_mdl
                    ): sub_query
                    for sub_query in sub_queries
                }

                # Collect results as they complete
                for future in as_completed(future_to_query):
                    sub_query = future_to_query[future]
                    try:
                        chunks = future.result()
                        chunks_by_query[sub_query] = chunks
                        logging.info(f"Retrieved {len(chunks)} chunks for sub-query: {sub_query}")
                    except Exception as e:
                        logging.exception(f"Retrieval failed for sub-query '{sub_query}': {str(e)}")
                        chunks_by_query[sub_query] = []
        else:
            # Sequential execution
            for sub_query in sub_queries:
                try:
                    chunks = self._retrieve_for_sub_query(
                        sub_query,
                        kb_ids,
                        kbs,
                        embd_mdl,
                        rerank_mdl
                    )
                    chunks_by_query[sub_query] = chunks
                    logging.info(f"Retrieved {len(chunks)} chunks for sub-query: {sub_query}")
                except Exception as e:
                    logging.exception(f"Retrieval failed for sub-query '{sub_query}': {str(e)}")
                    chunks_by_query[sub_query] = []

        return chunks_by_query

    def _retrieve_for_sub_query(
        self,
        sub_query: str,
        kb_ids: List[str],
        kbs: List,
        embd_mdl,
        rerank_mdl
    ) -> List[Dict]:
        """
        Retrieve chunks for a single sub-query.

        This method is called for each sub-query (either sequentially or in parallel).
        It performs standard vector retrieval and returns the top-k candidates.

        Args:
            sub_query (str): Single sub-question to retrieve for
            kb_ids (List[str]): Knowledge base IDs
            kbs (List): Knowledge base objects
            embd_mdl: Embedding model bundle
            rerank_mdl: Reranking model bundle

        Returns:
            List[Dict]: Retrieved chunks for this sub-query
        """
        # Cap the candidate pool per sub-query at a reasonable upper limit;
        # global deduplication and reranking happen later
        top_k = min(self._param.top_k, 2048)

        # Perform vector retrieval
        kbinfos = settings.retriever.retrieval(
            question=sub_query,
            embd_mdl=embd_mdl,
            tenant_ids=[kb.tenant_id for kb in kbs],
            kb_ids=kb_ids,
            page=1,
            page_size=self._param.top_n * 2,  # Retrieve more for deduplication
            similarity_threshold=self._param.similarity_threshold * 0.8,  # Slightly lower threshold for sub-queries
            vector_similarity_weight=1 - self._param.keywords_similarity_weight,
            top=top_k,
            doc_ids=None,
            aggs=False,  # Don't need document aggregations for sub-queries
            rerank_mdl=None  # We'll do global reranking later
        )

        chunks = kbinfos.get("chunks", [])

        # Store the sub-query with each chunk for later LLM scoring
        for chunk in chunks:
            chunk["_sub_query"] = sub_query

        return chunks

    def _global_rerank_and_deduplicate(
        self,
        chunks_by_query: Dict[str, List[Dict]],
        sub_queries: List[str],
        original_query: str
    ) -> List[Dict]:
        """
        Perform global deduplication and reranking of chunks from all sub-queries.

        This is the core innovation of the query decomposition approach. Instead of
        treating sub-query results separately, we:
        1. Deduplicate chunks across all sub-queries (same chunk may appear multiple times)
        2. Use LLM to score each unique chunk's relevance
        3. Fuse LLM scores with original vector similarity scores
        4. Return globally ranked, deduplicated results

        Args:
            chunks_by_query (Dict[str, List[Dict]]): Chunks retrieved for each sub-query
            sub_queries (List[str]): List of sub-questions
            original_query (str): The original complex query

        Returns:
            List[Dict]: Globally ranked and deduplicated chunks
        """
        logging.info("Performing global deduplication and reranking")

        # Step 1: Deduplicate chunks by chunk_id
        # Keep track of which sub-queries retrieved each chunk
        chunk_map = {}  # chunk_id -> {chunk_data, sub_queries, vector_scores}

        for sub_query, chunks in chunks_by_query.items():
            for chunk in chunks:
                chunk_id = chunk.get("chunk_id")
                if not chunk_id:
                    continue

                if chunk_id not in chunk_map:
                    # First time seeing this chunk
                    chunk_map[chunk_id] = {
                        "chunk": chunk,
                        "sub_queries": [sub_query],
                        "vector_scores": [chunk.get("similarity", 0.0)]
                    }
                else:
                    # Chunk appeared in multiple sub-queries
                    chunk_map[chunk_id]["sub_queries"].append(sub_query)
                    chunk_map[chunk_id]["vector_scores"].append(chunk.get("similarity", 0.0))

        if not chunk_map:
            logging.warning("No chunks to rerank after deduplication")
            return []

        logging.info(f"Deduplicated {sum(len(chunks) for chunks in chunks_by_query.values())} chunks to {len(chunk_map)} unique chunks")

        # Step 2: LLM-based scoring of each unique chunk
        # Score each chunk against the original query (not sub-queries)
        scored_chunks = []

        for chunk_id, chunk_info in chunk_map.items():
            chunk = chunk_info["chunk"]

            # Get LLM relevance score
            llm_score = self._score_chunk_with_llm(
                chunk=chunk,
                query=original_query  # Score against original query for global relevance
            )

            # Get average vector similarity score across all sub-queries that retrieved this chunk
            avg_vector_score = np.mean(chunk_info["vector_scores"])

            # Fuse LLM score with vector similarity score
            # Final score = weight * LLM_score + (1-weight) * vector_score
            final_score = (
                self._param.score_fusion_weight * llm_score +
                (1 - self._param.score_fusion_weight) * avg_vector_score
            )
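            # Worked example with the default weight of 0.7 (illustrative numbers):
            # an LLM score of 8/10 normalizes to (8 - 1) / 9 ≈ 0.78, so with an average
            # vector similarity of 0.55 the fused score is roughly
            # 0.7 * 0.78 + 0.3 * 0.55 ≈ 0.71.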

            # Store scores in chunk for transparency
            chunk["llm_relevance_score"] = llm_score
            chunk["vector_similarity_score"] = avg_vector_score
            chunk["final_fused_score"] = final_score
            chunk["retrieved_by_sub_queries"] = chunk_info["sub_queries"]

            scored_chunks.append(chunk)

        # Step 3: Sort by final fused score
        scored_chunks.sort(key=lambda x: x.get("final_fused_score", 0.0), reverse=True)

        # Step 4: Return top N results
        final_results = scored_chunks[:self._param.top_n]

        logging.info(f"Global reranking complete - returning top {len(final_results)} chunks")

        return final_results

    def _score_chunk_with_llm(self, chunk: Dict, query: str) -> float:
        """
        Score a chunk's relevance to a query using LLM.

        This method calls the LLM with the reranking prompt to judge how useful
        the chunk is for answering the query. The LLM returns a score from 1-10
        which is then normalized to the 0.0-1.0 range.

        Args:
            chunk (Dict): Chunk to score
            query (str): Query to score against

        Returns:
            float: Normalized relevance score (0.0-1.0)
        """
        try:
            # Get LLM for scoring
            llm = LLMBundle(
                self._canvas.get_tenant_id(),
                LLMType.CHAT
            )

            # Extract chunk text
            chunk_text = chunk.get("content_with_weight", chunk.get("content_ltks", ""))
            if not chunk_text:
                logging.warning(f"Chunk {chunk.get('chunk_id')} has no content for scoring")
                return 0.0

            # Truncate chunk text if too long (to avoid token limits)
            max_chunk_length = 2000
            if len(chunk_text) > max_chunk_length:
                chunk_text = chunk_text[:max_chunk_length] + "..."

            # Format the reranking prompt
            prompt = self._param.reranking_prompt.format(
                query=query,
                chunk_text=chunk_text
            )

            # Call LLM to score the chunk
            response = llm.chat(
                system="You are an expert at assessing information relevance.",
                messages=[{"role": "user", "content": prompt}],
                gen_conf={"temperature": 0.1, "max_tokens": 200}
            )

            response_text = response.strip()

            # Parse JSON response: {"score": 8, "reason": "..."}
            score_data = self._parse_score_from_response(response_text)

            if score_data:
                # Normalize score from 1-10 range to 0.0-1.0 range
                raw_score = score_data.get("score", 5)
                normalized_score = (raw_score - 1) / 9.0  # (score-1)/9 maps [1,10] to [0,1]
                normalized_score = max(0.0, min(1.0, normalized_score))  # Clamp to [0,1]

                logging.debug(f"LLM scored chunk {chunk.get('chunk_id')}: {raw_score}/10 (normalized: {normalized_score:.3f})")

                return normalized_score
            else:
                # Failed to parse score - return neutral score
                logging.warning(f"Failed to parse LLM score from response: {response_text}")
                return 0.5

        except Exception as e:
            # If scoring fails, return neutral score
            logging.exception(f"LLM scoring failed for chunk {chunk.get('chunk_id')}: {str(e)}")
            return 0.5

    def _parse_score_from_response(self, response_text: str) -> Dict:
        """
        Parse score and reason from LLM response.

        Expected format: {"score": 8, "reason": "Contains relevant information"}

        Args:
            response_text (str): Raw LLM response

        Returns:
            Dict with "score" and "reason", or None if parsing fails
        """
        # Try direct JSON parsing
        try:
            data = json.loads(response_text)
            if "score" in data:
                return data
        except json.JSONDecodeError:
            pass

        # Try to extract JSON object from response
        json_pattern = r'\{[^}]+\}'
        matches = re.findall(json_pattern, response_text)

        for match in matches:
            try:
                data = json.loads(match)
                if "score" in data:
                    return data
            except json.JSONDecodeError:
                continue

        # Try to extract score using regex
        # Look for patterns like: "score": 8 or score: 8
        score_pattern = r'["\']?score["\']?\s*:\s*(\d+)'
        match = re.search(score_pattern, response_text, re.IGNORECASE)
        if match:
            return {"score": int(match.group(1)), "reason": ""}

        return None

    def _format_and_set_output(self, chunks: List[Dict]):
        """
        Format retrieved chunks and set output variables.

        This method:
        1. Cleans up internal fields from chunks
        2. Formats chunks for JSON output
        3. Adds references to canvas for citation tracking
        4. Formats chunks as text for downstream components
        5. Sets output variables

        Args:
            chunks (List[Dict]): Final ranked and deduplicated chunks
        """
        if not chunks:
            logging.info("No chunks to format - returning empty response")
            self.set_output("formalized_content", self._param.empty_response)
            self.set_output("json", [])
            return

        # Clean up internal fields that shouldn't be exposed
        for chunk in chunks:
            # Remove internal fields used during processing
            chunk.pop("_sub_query", None)
            chunk.pop("vector", None)
            chunk.pop("content_ltks", None)

        # Prepare JSON output
        json_output = chunks.copy()

        # Add references to canvas for citation tracking
        # This allows the UI to show source documents
        doc_aggs = []  # Document aggregations
        self._canvas.add_reference(chunks, doc_aggs)

        # Format chunks as text for downstream components
        # This creates a formatted string with all chunk contents
        formalized_content = "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": doc_aggs}, 200000, True))

        # Set output variables
        self.set_output("formalized_content", formalized_content)
        self.set_output("json", json_output)

        logging.info(f"Formatted {len(chunks)} chunks for output")

    def thoughts(self) -> str:
        """
        Return component thoughts for debugging/logging.

        This method is called by the agent framework to get a description of
        what the component is doing. Useful for debugging and monitoring.

        Returns:
            str: Description of component processing
        """
        if self._param.enable_decomposition:
            return (
                f"Performing advanced retrieval with query decomposition. "
                f"Will decompose complex queries into up to {self._param.max_decomposition_count} "
                f"sub-questions and use LLM-based reranking with {self._param.score_fusion_weight} "
                f"weight on LLM scores."
            )
        else:
            return "Performing standard retrieval without query decomposition."