openrag/src/services/search_service.py
2025-09-03 09:17:30 -04:00

219 lines
8.2 KiB
Python

from typing import Any, Dict, Optional
from agentd.tool_decorator import tool
from config.settings import clients, INDEX_NAME, EMBED_MODEL
from auth_context import get_auth_context
class SearchService:
    """Document search service backed by the OpenSearch index ``INDEX_NAME``.

    Exposes two entry points:

    * ``search_tool`` - decorated with ``@tool``; reads the caller identity,
      filters, result limit and score threshold from ``auth_context``.
    * ``search`` - public method for API endpoints; stores its arguments in
      ``auth_context`` and then delegates to ``search_tool``.
    """

    # Frontend filter names -> index field names. Keys that are not listed
    # here are used as the field name unchanged.
    _FILTER_FIELD_MAP = {
        "data_sources": "filename",
        "document_types": "mimetype",
        "owners": "owner_name.keyword",
        "connector_types": "connector_type",
    }

    def __init__(self, session_manager=None):
        """Store the session manager used to obtain per-user OpenSearch clients."""
        self.session_manager = session_manager

    @staticmethod
    def _build_filter_clauses(filters):
        """Translate frontend filters into a list of ``term``/``terms`` clauses.

        Only list-valued entries are considered; anything else (including
        ``None``) is skipped. An empty list means "match nothing".
        """
        clauses = []
        for filter_key, values in filters.items():
            if not isinstance(values, list):
                continue
            field_name = SearchService._FILTER_FIELD_MAP.get(filter_key, filter_key)
            if not values:
                # Empty array means "match nothing" - use impossible filter
                clauses.append({"term": {field_name: "__IMPOSSIBLE_VALUE__"}})
            elif len(values) == 1:
                clauses.append({"term": {field_name: values[0]}})
            else:
                clauses.append({"terms": {field_name: values}})
        return clauses

    @staticmethod
    def _build_query_block(query, query_embedding, filter_clauses, match_all):
        """Build the ``query`` section of the search body.

        With ``match_all`` the result is either a pure filter query or
        ``match_all``; otherwise it is a hybrid ``bool`` query combining a kNN
        clause on ``chunk_embedding`` with a fuzzy ``multi_match`` clause.
        """
        if match_all:
            # Match all documents; still allow filters to narrow scope
            if filter_clauses:
                return {"bool": {"filter": filter_clauses}}
            return {"match_all": {}}

        # Hybrid search query structure (semantic + keyword)
        bool_query = {
            "should": [
                {
                    "knn": {
                        "chunk_embedding": {
                            "vector": query_embedding,
                            # NOTE(review): k is fixed at 10 regardless of the
                            # requested result limit - confirm this is intended.
                            "k": 10,
                            "boost": 0.7,
                        }
                    }
                },
                {
                    "multi_match": {
                        "query": query,
                        "fields": ["text^2", "filename^1.5"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                        "boost": 0.3,
                    }
                },
            ],
            "minimum_should_match": 1,
        }
        if filter_clauses:
            bool_query["filter"] = filter_clauses
        return {"bool": bool_query}

    @staticmethod
    def _format_hit(hit):
        """Flatten one OpenSearch hit into the chunk dict returned to callers.

        Every ``_source`` field is read with ``.get`` so that a document
        lacking a field (for example ``page``) yields ``None`` for that key
        instead of raising ``KeyError`` and failing the whole search.
        """
        source = hit["_source"]
        return {
            "filename": source.get("filename"),
            "mimetype": source.get("mimetype"),
            "page": source.get("page"),
            "text": source.get("text"),
            "score": hit["_score"],
            "source_url": source.get("source_url"),
            "owner": source.get("owner"),
            "owner_name": source.get("owner_name"),
            "owner_email": source.get("owner_email"),
            "file_size": source.get("file_size"),
            "connector_type": source.get("connector_type"),
        }

    @tool
    async def search_tool(self, query: str) -> Dict[str, Any]:
        """
        Use this tool to search for documents relevant to the query.

        Args:
            query (str): query string to search the corpus

        Returns:
            dict (str, Any): {"results": [chunks]} on success
        """
        # NOTE(review): the docstring above is presumably consumed by @tool as
        # the tool description, so its wording is intentionally left as-is.
        #
        # Additional contract details:
        #   * A query of "*" (ignoring surrounding whitespace) skips embedding
        #     and returns documents/facets via match_all (or filters only).
        #   * On success the dict also carries "aggregations" and "total".
        #   * Without a user id in the auth context the result is
        #     {"results": [], "error": "Authentication required"}.
        #   * Exceptions from the OpenSearch query are logged and re-raised.

        # Get authentication context from the current async context
        user_id, jwt_token = get_auth_context()

        # Authentication required - DLS will handle document filtering automatically.
        # Checked before any remote call so unauthenticated requests do not
        # trigger an embedding request.
        print(
            f"[DEBUG] search_service: user_id={user_id}, jwt_token={'None' if jwt_token is None else 'present'}"
        )
        if not user_id:
            print(
                "[DEBUG] search_service: user_id is None/empty, returning auth error"
            )
            return {"results": [], "error": "Authentication required"}

        # Get search filters, limit, and score threshold from context
        from auth_context import (
            get_search_filters,
            get_search_limit,
            get_score_threshold,
        )

        filters = get_search_filters() or {}
        limit = get_search_limit()
        score_threshold = get_score_threshold()

        # Detect wildcard request ("*") to return global facets/stats without semantic search
        is_wildcard_match_all = isinstance(query, str) and query.strip() == "*"

        # Only embed when not doing match_all
        query_embedding = None
        if not is_wildcard_match_all:
            resp = await clients.patched_async_client.embeddings.create(
                model=EMBED_MODEL, input=[query]
            )
            query_embedding = resp.data[0].embedding

        filter_clauses = self._build_filter_clauses(filters)
        query_block = self._build_query_block(
            query, query_embedding, filter_clauses, is_wildcard_match_all
        )

        search_body = {
            "query": query_block,
            "aggs": {
                "data_sources": {"terms": {"field": "filename", "size": 20}},
                "document_types": {"terms": {"field": "mimetype", "size": 10}},
                "owners": {"terms": {"field": "owner_name.keyword", "size": 10}},
                "connector_types": {"terms": {"field": "connector_type", "size": 10}},
            },
            "_source": [
                "filename",
                "mimetype",
                "page",
                "text",
                "source_url",
                "owner",
                "owner_name",
                "owner_email",
                "file_size",
                "connector_type",
                "allowed_users",
                "allowed_groups",
            ],
            "size": limit,
        }

        # Add score threshold only for hybrid (not meaningful for match_all).
        # The None guard avoids a TypeError should the context hold no threshold.
        if (
            not is_wildcard_match_all
            and score_threshold is not None
            and score_threshold > 0
        ):
            search_body["min_score"] = score_threshold

        # Get user's OpenSearch client with JWT for OIDC auth through session manager
        opensearch_client = self.session_manager.get_user_opensearch_client(
            user_id, jwt_token
        )

        try:
            results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
        except Exception as e:
            print(f"[ERROR] OpenSearch query failed: {e}")
            print(f"[ERROR] Search body: {search_body}")
            # Re-raise the exception so the API returns the error to frontend
            raise

        # Transform results (keep for backward compatibility)
        hits_section = results.get("hits", {})
        chunks = [self._format_hit(hit) for hit in hits_section.get("hits", [])]

        # "total" is either {"value": n, ...} or a bare number depending on
        # the response format; normalise to the number.
        total = hits_section.get("total")
        if isinstance(total, dict):
            total = total.get("value")

        # Return both transformed results and aggregations
        return {
            "results": chunks,
            "aggregations": results.get("aggregations", {}),
            "total": total,
        }

    async def search(
        self,
        query: str,
        user_id: Optional[str] = None,
        jwt_token: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,
        limit: int = 10,
        score_threshold: float = 0,
    ) -> Dict[str, Any]:
        """Public search method for API endpoints.

        Stores the supplied identity, filters, limit and score threshold in
        ``auth_context`` and then runs ``search_tool``.

        Args:
            query: Query string, or "*" to match all documents.
            user_id: Caller identity; only stored when a JWT is also given or
                no-auth mode is active.
            jwt_token: JWT passed on to the per-user OpenSearch client.
            filters: Mapping of filter name to list of accepted values.
            limit: Maximum number of hits to return.
            score_threshold: Minimum score for hybrid queries; values <= 0
                disable the threshold.

        Returns:
            The dict produced by ``search_tool``.
        """
        # Set auth context if provided (for direct API calls)
        from config.settings import is_no_auth_mode

        if user_id and (jwt_token or is_no_auth_mode()):
            from auth_context import set_auth_context

            set_auth_context(user_id, jwt_token)

        # Set filters and limit in context if provided.
        # NOTE(review): when `filters` is empty/None nothing is written, so any
        # filters already present in the context stay in effect - confirm
        # against auth_context whether that is intended.
        if filters:
            from auth_context import set_search_filters

            set_search_filters(filters)

        from auth_context import set_search_limit, set_score_threshold

        set_search_limit(limit)
        set_score_threshold(score_threshold)

        return await self.search_tool(query)