# openrag/src/services/search_service.py
# Last modified: 2025-10-10 22:14:51 -04:00
# Standard library
from typing import Any, Dict

# Project-local
from agentd.tool_decorator import tool
from config.settings import clients, INDEX_NAME, EMBED_MODEL
from auth_context import get_auth_context
from utils.logging_config import get_logger

# Module-level structured logger, named after this module per convention.
logger = get_logger(__name__)
class SearchService:
    """Hybrid (semantic + keyword) search over the OpenSearch document corpus.

    The query is embedded with every embedding model detected in the corpus,
    so documents indexed under different models remain findable.  Per-user
    OpenSearch clients (obtained through ``session_manager``) let DLS handle
    document-level access filtering automatically.
    """

    # Maps frontend filter keys to the backend OpenSearch field names.
    _FILTER_FIELD_MAPPING = {
        "data_sources": "filename",
        "document_types": "mimetype",
        "owners": "owner_name.keyword",
        "connector_types": "connector_type",
    }

    def __init__(self, session_manager=None):
        # session_manager supplies per-user OpenSearch clients (JWT/OIDC auth).
        self.session_manager = session_manager

    @staticmethod
    def _build_filter_clauses(filters: Dict[str, Any]) -> list:
        """Translate a frontend filter dict into OpenSearch bool-filter clauses.

        Semantics per key (only list values are honoured):
          - empty list  -> "match nothing" (impossible term filter)
          - one value   -> ``term`` filter
          - many values -> ``terms`` filter
        Unknown keys are passed through as field names unchanged.
        """
        filter_clauses = []
        if filters:
            for filter_key, values in filters.items():
                if values is not None and isinstance(values, list):
                    # Map frontend key to backend field name.
                    field_name = SearchService._FILTER_FIELD_MAPPING.get(
                        filter_key, filter_key
                    )
                    if len(values) == 0:
                        # Empty array means "match nothing" - use impossible filter
                        filter_clauses.append(
                            {"term": {field_name: "__IMPOSSIBLE_VALUE__"}}
                        )
                    elif len(values) == 1:
                        # Single value filter
                        filter_clauses.append({"term": {field_name: values[0]}})
                    else:
                        # Multiple values filter
                        filter_clauses.append({"terms": {field_name: values}})
        return filter_clauses

    async def _detect_embedding_models(
        self, opensearch_client, filter_clauses: list, fallback_model: str
    ) -> list:
        """Return the embedding-model names present in the corpus.

        Runs a ``terms`` aggregation on ``embedding_model`` (scoped by the
        user's filters so only models that can actually match are embedded
        for).  Falls back to ``fallback_model`` when the corpus is empty or
        the aggregation fails.
        """
        try:
            agg_query = {
                "size": 0,
                "aggs": {
                    "embedding_models": {
                        "terms": {"field": "embedding_model", "size": 10}
                    }
                },
            }
            # Apply user filters to model detection when present.
            if filter_clauses:
                agg_query["query"] = {"bool": {"filter": filter_clauses}}
            agg_result = await opensearch_client.search(
                index=INDEX_NAME, body=agg_query
            )
            buckets = (
                agg_result.get("aggregations", {})
                .get("embedding_models", {})
                .get("buckets", [])
            )
            available_models = [b["key"] for b in buckets if b["key"]]
            if not available_models:
                # No documents indexed yet - fall back to the configured model.
                available_models = [fallback_model]
            logger.info(
                "Detected embedding models in corpus",
                available_models=available_models,
                model_counts={b["key"]: b["doc_count"] for b in buckets},
                with_filters=len(filter_clauses) > 0,
            )
            return available_models
        except Exception as e:
            logger.warning(
                "Failed to detect embedding models, using configured model",
                error=str(e),
            )
            return [fallback_model]

    async def _embed_query(self, query: str, available_models: list) -> Dict[str, Any]:
        """Embed *query* with every model in *available_models*, in parallel.

        Returns ``{model_name: embedding_vector}``.  Models whose embedding
        call fails are dropped (and logged) so one bad model does not sink
        the whole search.
        """
        import asyncio

        async def embed_with_model(model_name):
            try:
                resp = await clients.patched_async_client.embeddings.create(
                    model=model_name, input=[query]
                )
                return model_name, resp.data[0].embedding
            except Exception as e:
                # Structured logging (message + kwargs) for consistency with
                # the rest of this module.
                logger.error(
                    "Failed to embed with model", model=model_name, error=str(e)
                )
                return model_name, None

        embedding_results = await asyncio.gather(
            *[embed_with_model(model) for model in available_models],
            return_exceptions=True,
        )
        # Keep only successful (model, vector) pairs; gather may also yield
        # raw exceptions because of return_exceptions=True.
        query_embeddings: Dict[str, Any] = {}
        for result in embedding_results:
            if isinstance(result, tuple) and result[1] is not None:
                model_name, embedding = result
                query_embeddings[model_name] = embedding
        logger.info(
            "Generated query embeddings",
            models=list(query_embeddings.keys()),
            query_preview=query[:50],
        )
        return query_embeddings

    @staticmethod
    def _hit_to_chunk(hit: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten one OpenSearch hit into the chunk dict the API returns."""
        src = hit["_source"]
        return {
            "filename": src.get("filename"),
            "mimetype": src.get("mimetype"),
            "page": src.get("page"),
            "text": src.get("text"),
            "score": hit.get("_score"),
            "source_url": src.get("source_url"),
            "owner": src.get("owner"),
            "owner_name": src.get("owner_name"),
            "owner_email": src.get("owner_email"),
            "file_size": src.get("file_size"),
            "connector_type": src.get("connector_type"),
            "embedding_model": src.get("embedding_model"),  # surfaced to clients
            "embedding_dimensions": src.get("embedding_dimensions"),
        }

    @tool
    async def search_tool(self, query: str, embedding_model: str = None) -> Dict[str, Any]:
        """
        Use this tool to search for documents relevant to the query.

        Args:
            query (str): query string to search the corpus.  The literal
                string "*" returns global facets/stats without semantic search.
            embedding_model (str): Optional override for embedding model.
                If not provided, uses EMBED_MODEL from config.

        Returns:
            dict (str, Any): {"results": [chunks], "aggregations": {...},
                "total": int} on success;
                {"results": [], "error": "Authentication required"} when no
                user is present in the auth context.

        Raises:
            Exception: OpenSearch query failures are logged and re-raised so
                the API layer can surface them to the frontend.
        """
        from utils.embedding_fields import get_embedding_field_name

        # Strategy: use the caller-provided model, else default to EMBED_MODEL.
        # (Documents are assumed to be embedded with EMBED_MODEL by default;
        # the model auto-detection below covers mixed-model corpora.)
        embedding_model = embedding_model or EMBED_MODEL
        embedding_field_name = get_embedding_field_name(embedding_model)
        logger.info(
            "Search with embedding model",
            embedding_model=embedding_model,
            embedding_field=embedding_field_name,
            query_preview=query[:50] if query else None,
        )

        # Authentication context is set by the caller (API layer or search()).
        user_id, jwt_token = get_auth_context()
        logger.debug(
            "search_service authentication info",
            user_id=user_id,
            has_jwt_token=jwt_token is not None,
        )
        # Fail fast: DLS document filtering requires an authenticated user,
        # so nothing below should touch OpenSearch without one.
        if not user_id:
            logger.debug("search_service: user_id is None/empty, returning auth error")
            return {"results": [], "error": "Authentication required"}

        # Search filters, limit and score threshold travel via context vars.
        from auth_context import (
            get_search_filters,
            get_search_limit,
            get_score_threshold,
        )

        filters = get_search_filters() or {}
        limit = get_search_limit()
        score_threshold = get_score_threshold()

        # "*" means: skip semantic search and return global facets/stats only.
        is_wildcard_match_all = isinstance(query, str) and query.strip() == "*"

        # User-scoped OpenSearch client (JWT for OIDC) via the session manager.
        opensearch_client = self.session_manager.get_user_opensearch_client(
            user_id, jwt_token
        )

        filter_clauses = self._build_filter_clauses(filters)

        query_embeddings: Dict[str, Any] = {}
        if not is_wildcard_match_all:
            # Only embed with models that actually exist in the (filtered)
            # corpus, then embed the query with all of them in parallel.
            available_models = await self._detect_embedding_models(
                opensearch_client, filter_clauses, embedding_model
            )
            query_embeddings = await self._embed_query(query, available_models)

        # ---- Build the query block ------------------------------------
        if is_wildcard_match_all:
            # Match all documents; filters may still narrow the scope.
            if filter_clauses:
                query_block = {"bool": {"filter": filter_clauses}}
            else:
                query_block = {"match_all": {}}
        else:
            # One KNN sub-query per embedding model the query was embedded with.
            knn_queries = []
            embedding_fields_to_check = []
            for model_name, embedding_vector in query_embeddings.items():
                field_name = get_embedding_field_name(model_name)
                embedding_fields_to_check.append(field_name)
                knn_queries.append(
                    {
                        "knn": {
                            field_name: {
                                "vector": embedding_vector,
                                "k": 50,
                                "num_candidates": 1000,
                            }
                        }
                    }
                )

            # A document must carry at least one of the embedding fields.
            exists_any_embedding = {
                "bool": {
                    "should": [
                        {"exists": {"field": f}} for f in embedding_fields_to_check
                    ],
                    "minimum_should_match": 1,
                }
            }
            all_filters = [*filter_clauses, exists_any_embedding]
            logger.debug(
                "Building hybrid query with filters",
                user_filters_count=len(filter_clauses),
                total_filters_count=len(all_filters),
                filter_types=[type(f).__name__ for f in all_filters],
            )

            # Hybrid search (semantic + keyword).  dis_max picks the single
            # best score across the per-model KNN queries.
            query_block = {
                "bool": {
                    "should": [
                        {
                            "dis_max": {
                                "tie_breaker": 0.0,  # Take only the best match, no blending
                                "boost": 0.7,  # 70% weight for semantic search
                                "queries": knn_queries,
                            }
                        },
                        {
                            "multi_match": {
                                "query": query,
                                "fields": ["text^2", "filename^1.5"],
                                "type": "best_fields",
                                "fuzziness": "AUTO",
                                "boost": 0.3,  # 30% weight for keyword search
                            }
                        },
                    ],
                    "minimum_should_match": 1,
                    "filter": all_filters,
                }
            }

        search_body = {
            "query": query_block,
            "aggs": {
                "data_sources": {"terms": {"field": "filename", "size": 20}},
                "document_types": {"terms": {"field": "mimetype", "size": 10}},
                "owners": {"terms": {"field": "owner_name.keyword", "size": 10}},
                "connector_types": {"terms": {"field": "connector_type", "size": 10}},
                "embedding_models": {
                    "terms": {"field": "embedding_model.keyword", "size": 10}
                },
            },
            "_source": [
                "filename",
                "mimetype",
                "page",
                "text",
                "source_url",
                "owner",
                "owner_name",
                "owner_email",
                "file_size",
                "connector_type",
                "embedding_model",  # Include embedding model in results
                "embedding_dimensions",
                "allowed_users",
                "allowed_groups",
            ],
            "size": limit,
        }
        # Score threshold only applies to hybrid scoring (not match_all).
        if not is_wildcard_match_all and score_threshold > 0:
            search_body["min_score"] = score_threshold

        try:
            results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
        except Exception as e:
            logger.error(
                "OpenSearch query failed", error=str(e), search_body=search_body
            )
            # Re-raise so the API returns the error to the frontend.
            raise

        # Flattened chunk dicts are kept for backward compatibility.
        chunks = [self._hit_to_chunk(hit) for hit in results["hits"]["hits"]]

        # OpenSearch may report total as {"value": N, ...} or a bare number.
        total = results.get("hits", {}).get("total")
        return {
            "results": chunks,
            "aggregations": results.get("aggregations", {}),
            "total": total.get("value") if isinstance(total, dict) else total,
        }

    async def search(
        self,
        query: str,
        user_id: str = None,
        jwt_token: str = None,
        filters: Dict[str, Any] = None,
        limit: int = 10,
        score_threshold: float = 0,
        embedding_model: str = None,
    ) -> Dict[str, Any]:
        """Public search method for API endpoints.

        Propagates auth and search parameters through the async context vars
        that ``search_tool`` reads, then delegates to it.

        Args:
            query: query string ("*" for facets-only wildcard search).
            user_id: authenticated user id; required for results.
            jwt_token: JWT for OIDC auth (optional in no-auth mode).
            filters: optional frontend filter dict (e.g. {"data_sources": [...]}).
            limit: maximum number of chunks to return.
            score_threshold: minimum hybrid score; 0 disables the threshold.
            embedding_model: embedding model to use for search
                (defaults to EMBED_MODEL).
        """
        from config.settings import is_no_auth_mode

        # Set auth context if provided (for direct API calls).  In no-auth
        # mode a JWT is not required.
        if user_id and (jwt_token or is_no_auth_mode()):
            from auth_context import set_auth_context

            set_auth_context(user_id, jwt_token)

        # Set filters in context only when provided (non-empty).
        if filters:
            from auth_context import set_search_filters

            set_search_filters(filters)

        from auth_context import set_search_limit, set_score_threshold

        set_search_limit(limit)
        set_score_threshold(score_threshold)
        return await self.search_tool(query, embedding_model=embedding_model)