openrag/src/services/search_service.py
2025-08-14 15:06:38 -04:00

171 lines
No EOL
7 KiB
Python

from typing import Any, Dict, Optional
from agentd.tool_decorator import tool
from config.settings import clients, INDEX_NAME, EMBED_MODEL
from auth_context import get_auth_context
class SearchService:
    """Search the document index with a hybrid (kNN + keyword) query.

    ``search_tool`` reads its inputs (auth, filters, limit, score threshold)
    from ``auth_context``; ``search`` is a thin wrapper that writes those
    values into the context and then delegates to ``search_tool``.
    """

    # Maps frontend filter names to the backend index field they filter on.
    # Keys not listed here are used as the field name unchanged.
    _FILTER_FIELD_MAPPING = {
        "data_sources": "filename",
        "document_types": "mimetype",
        "owners": "owner",
    }

    # Term value used to turn an empty filter list into "match nothing".
    _IMPOSSIBLE_VALUE = "__IMPOSSIBLE_VALUE__"

    # Fields requested from OpenSearch for every hit.
    _SOURCE_FIELDS = [
        "filename",
        "mimetype",
        "page",
        "text",
        "source_url",
        "owner",
        "allowed_users",
        "allowed_groups",
    ]

    # Facet name -> (index field, bucket count) for the terms aggregations.
    _FACETS = {
        "data_sources": ("filename", 20),
        "document_types": ("mimetype", 10),
        "owners": ("owner", 10),
    }

    def __init__(self, session_manager=None):
        """Store the optional session manager (not read by the search methods)."""
        self.session_manager = session_manager

    @staticmethod
    def _build_filter_clauses(filters: Dict[str, Any]) -> list:
        """Translate frontend filters into OpenSearch ``term``/``terms`` clauses.

        Only list-valued entries produce a clause; any other value (including
        ``None``) is ignored. An empty list yields a clause that matches no
        document.
        """
        clauses = []
        for filter_key, values in filters.items():
            if not isinstance(values, list):
                continue
            field_name = SearchService._FILTER_FIELD_MAPPING.get(filter_key, filter_key)
            if not values:
                # Empty array means "match nothing" - use impossible filter
                clauses.append({"term": {field_name: SearchService._IMPOSSIBLE_VALUE}})
            elif len(values) == 1:
                clauses.append({"term": {field_name: values[0]}})
            else:
                clauses.append({"terms": {field_name: values}})
        return clauses

    @staticmethod
    def _build_match_all_query(filter_clauses: list) -> Dict[str, Any]:
        """Build the wildcard ("*") query: every document, narrowed by filters."""
        if filter_clauses:
            return {"bool": {"filter": filter_clauses}}
        return {"match_all": {}}

    @staticmethod
    def _build_hybrid_query(query: str, query_embedding: Any, filter_clauses: list) -> Dict[str, Any]:
        """Build the hybrid query: kNN on ``chunk_embedding`` OR fuzzy keyword match."""
        bool_query: Dict[str, Any] = {
            "should": [
                {
                    "knn": {
                        "chunk_embedding": {
                            "vector": query_embedding,
                            # NOTE(review): k is fixed at 10 regardless of the
                            # requested result limit - confirm this is intended
                            # for limits above 10.
                            "k": 10,
                            "boost": 0.7,
                        }
                    }
                },
                {
                    "multi_match": {
                        "query": query,
                        "fields": ["text^2", "filename^1.5"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                        "boost": 0.3,
                    }
                },
            ],
            "minimum_should_match": 1,
        }
        if filter_clauses:
            bool_query["filter"] = filter_clauses
        return {"bool": bool_query}

    @staticmethod
    def _build_aggregations() -> Dict[str, Any]:
        """Build the terms aggregations returned alongside the hits."""
        return {
            facet: {"terms": {"field": field, "size": size}}
            for facet, (field, size) in SearchService._FACETS.items()
        }

    @staticmethod
    def _hit_to_chunk(hit: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten one OpenSearch hit into the chunk dict returned to callers.

        Every ``_source`` field is read with ``.get`` so that a single document
        missing a field yields ``None`` for it instead of failing the search.
        """
        source = hit["_source"]
        return {
            "filename": source.get("filename"),
            "mimetype": source.get("mimetype"),
            "page": source.get("page"),
            "text": source.get("text"),
            "score": hit.get("_score"),
            "source_url": source.get("source_url"),
            "owner": source.get("owner"),
        }

    @staticmethod
    def _extract_total(results: Dict[str, Any]) -> Any:
        """Return the total hit count, whether ``hits.total`` is a dict or a bare value."""
        total = results.get("hits", {}).get("total")
        if isinstance(total, dict):
            return total.get("value")
        return total

    # NOTE(review): the docstring below is presumably consumed by the @tool
    # decorator as the tool description, so its wording is kept unchanged -
    # confirm before editing it. Additional return keys produced here:
    # "aggregations" and "total" on success, "error" when unauthenticated.
    @tool
    async def search_tool(self, query: str) -> Dict[str, Any]:
        """
        Use this tool to search for documents relevant to the query.

        Args:
            query (str): query string to search the corpus

        Returns:
            dict (str, Any): {"results": [chunks]} on success
        """
        from auth_context import get_search_filters, get_search_limit, get_score_threshold

        # Authenticate before doing any work so unauthenticated callers never
        # trigger an embedding request.
        user_id, jwt_token = get_auth_context()
        if not user_id:
            return {"results": [], "error": "Authentication required"}

        # Search filters, limit, and score threshold come from the current context.
        filters = get_search_filters() or {}
        limit = get_search_limit()
        score_threshold = get_score_threshold()

        # A bare "*" requests global facets/stats without semantic search.
        is_wildcard_match_all = isinstance(query, str) and query.strip() == "*"

        filter_clauses = SearchService._build_filter_clauses(filters)

        if is_wildcard_match_all:
            query_block = SearchService._build_match_all_query(filter_clauses)
        else:
            # Only embed when not doing match_all.
            resp = await clients.patched_async_client.embeddings.create(
                model=EMBED_MODEL, input=[query]
            )
            query_embedding = resp.data[0].embedding
            query_block = SearchService._build_hybrid_query(
                query, query_embedding, filter_clauses
            )

        search_body: Dict[str, Any] = {
            "query": query_block,
            "aggs": SearchService._build_aggregations(),
            "_source": list(SearchService._SOURCE_FIELDS),
            "size": limit,
        }

        # Score threshold only applies to hybrid search (not meaningful for
        # match_all). The truthiness guard tolerates an unset (None) threshold.
        if not is_wildcard_match_all and score_threshold and score_threshold > 0:
            search_body["min_score"] = score_threshold

        # Per-user OpenSearch client built from the JWT (OIDC auth); per the
        # original note, document-level security filters results server-side.
        opensearch_client = clients.create_user_opensearch_client(jwt_token)
        results = await opensearch_client.search(index=INDEX_NAME, body=search_body)

        # Return both transformed results and aggregations.
        return {
            "results": [SearchService._hit_to_chunk(hit) for hit in results["hits"]["hits"]],
            "aggregations": results.get("aggregations", {}),
            "total": SearchService._extract_total(results),
        }

    async def search(
        self,
        query: str,
        user_id: Optional[str] = None,
        jwt_token: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,
        limit: int = 10,
        score_threshold: float = 0,
    ) -> Dict[str, Any]:
        """Public search method for API endpoints.

        Writes the supplied values into ``auth_context`` and delegates to
        ``search_tool``.

        Args:
            query: Query string; ``"*"`` requests a match-all search.
            user_id: Caller identity. Auth context is only set when both
                ``user_id`` and ``jwt_token`` are truthy.
            jwt_token: Token used to build the per-user OpenSearch client.
            filters: Frontend filter name -> list of accepted values.
            limit: Maximum number of hits to return.
            score_threshold: Minimum score for hybrid search; values <= 0
                disable the threshold.

        Returns:
            The dict produced by ``search_tool``.
        """
        from auth_context import (
            set_auth_context,
            set_score_threshold,
            set_search_filters,
            set_search_limit,
        )

        # Set auth context if provided (for direct API calls).
        if user_id and jwt_token:
            set_auth_context(user_id, jwt_token)

        # NOTE(review): when no filters are passed, filters already present in
        # the context are left in place - confirm that is the intended behavior.
        if filters:
            set_search_filters(filters)

        set_search_limit(limit)
        set_score_threshold(score_threshold)

        return await self.search_tool(query)