openrag/src/services/search_service.py
Gabriel Luiz Freitas Almeida 12cbd63dbf Refactor SearchService to improve data retrieval from search results
This commit updates the SearchService class to utilize the `get` method for safely accessing fields in the search results. This change enhances the robustness of the code by preventing potential KeyErrors and aligns with best practices for building maintainable async code. Additionally, it simplifies the data extraction process from the search results, improving overall code clarity.
2025-09-08 20:10:48 -03:00

223 lines
8.3 KiB
Python

from typing import Any, Dict
from agentd.tool_decorator import tool
from config.settings import clients, INDEX_NAME, EMBED_MODEL
from auth_context import get_auth_context
from utils.logging_config import get_logger
logger = get_logger(__name__)
class SearchService:
    """Search over the configured index, exposed as an agent tool and an API method.

    ``search_tool`` reads the caller identity, filters, result limit and score
    threshold from ``auth_context`` and runs either a hybrid (k-NN + keyword)
    query or, for the wildcard query ``"*"``, a match-all query.
    ``search`` is a convenience wrapper that stores its arguments in that
    context and then delegates to ``search_tool``.
    """

    # Frontend filter names -> index field names used in term/terms filters.
    # Keys that are not listed here are used as field names unchanged.
    _FILTER_FIELD_MAPPING = {
        "data_sources": "filename",
        "document_types": "mimetype",
        "owners": "owner_name.keyword",
        "connector_types": "connector_type",
    }

    def __init__(self, session_manager=None):
        """Store the session manager used to obtain per-user OpenSearch clients."""
        self.session_manager = session_manager

    @classmethod
    def _build_filter_clauses(cls, filters):
        """Translate the frontend filter dict into a list of filter clauses.

        Only list values are considered; any other value type is ignored.
        An empty list means "match nothing", a single value becomes a ``term``
        clause and several values become a ``terms`` clause.
        """
        clauses = []
        for filter_key, values in filters.items():
            if not isinstance(values, list):
                continue
            field_name = cls._FILTER_FIELD_MAPPING.get(filter_key, filter_key)
            if not values:
                # Empty selection: use a value that cannot match any document.
                clauses.append({"term": {field_name: "__IMPOSSIBLE_VALUE__"}})
            elif len(values) == 1:
                clauses.append({"term": {field_name: values[0]}})
            else:
                clauses.append({"terms": {field_name: values}})
        return clauses

    @staticmethod
    def _build_hybrid_query(query, query_embedding, filter_clauses):
        """Build the hybrid query: k-NN on the embedding plus a fuzzy keyword match.

        At least one of the two ``should`` clauses has to match; filter clauses,
        when present, are attached to the same ``bool`` query.
        """
        bool_query = {
            "should": [
                {
                    "knn": {
                        "chunk_embedding": {
                            "vector": query_embedding,
                            "k": 10,
                            "boost": 0.7,
                        }
                    }
                },
                {
                    "multi_match": {
                        "query": query,
                        "fields": ["text^2", "filename^1.5"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                        "boost": 0.3,
                    }
                },
            ],
            "minimum_should_match": 1,
        }
        if filter_clauses:
            bool_query["filter"] = filter_clauses
        return {"bool": bool_query}

    @staticmethod
    def _hit_to_chunk(hit):
        """Flatten one search hit into the chunk dict returned to callers.

        Missing fields (or a missing ``_source``) yield ``None`` values instead
        of raising ``KeyError``.
        """
        source = hit.get("_source") or {}
        return {
            "filename": source.get("filename"),
            "mimetype": source.get("mimetype"),
            "page": source.get("page"),
            "text": source.get("text"),
            "score": hit.get("_score"),
            "source_url": source.get("source_url"),
            "owner": source.get("owner"),
            "owner_name": source.get("owner_name"),
            "owner_email": source.get("owner_email"),
            "file_size": source.get("file_size"),
            "connector_type": source.get("connector_type"),
        }

    # NOTE(review): the docstring below may be consumed by the @tool decorator
    # as the tool description; its wording is kept as in the original. Confirm
    # before editing it.
    @tool
    async def search_tool(self, query: str) -> Dict[str, Any]:
        """
        Use this tool to search for documents relevant to the query.

        Args:
            query (str): query string to search the corpus

        Returns:
            dict (str, Any): {"results": [chunks]} on success
        """
        from auth_context import (
            get_search_filters,
            get_search_limit,
            get_score_threshold,
        )

        # Authentication is checked first so that an unauthenticated call is
        # rejected before any embedding request or query construction happens.
        user_id, jwt_token = get_auth_context()
        logger.debug(
            "search_service authentication info",
            user_id=user_id,
            has_jwt_token=jwt_token is not None,
        )
        if not user_id:
            logger.debug("search_service: user_id is None/empty, returning auth error")
            return {"results": [], "error": "Authentication required"}

        # Filters, limit and score threshold come from the current context.
        filters = get_search_filters() or {}
        limit = get_search_limit()
        score_threshold = get_score_threshold()

        # A bare "*" requests global facets/stats without semantic search.
        is_wildcard_match_all = isinstance(query, str) and query.strip() == "*"

        filter_clauses = self._build_filter_clauses(filters)

        if is_wildcard_match_all:
            # Match all documents; filters may still narrow the scope.
            if filter_clauses:
                query_block = {"bool": {"filter": filter_clauses}}
            else:
                query_block = {"match_all": {}}
        else:
            # Only embed the query when a semantic search is actually performed.
            resp = await clients.patched_async_client.embeddings.create(
                model=EMBED_MODEL, input=[query]
            )
            query_embedding = resp.data[0].embedding
            query_block = self._build_hybrid_query(
                query, query_embedding, filter_clauses
            )

        search_body = {
            "query": query_block,
            "aggs": {
                "data_sources": {"terms": {"field": "filename", "size": 20}},
                "document_types": {"terms": {"field": "mimetype", "size": 10}},
                "owners": {"terms": {"field": "owner_name.keyword", "size": 10}},
                "connector_types": {"terms": {"field": "connector_type", "size": 10}},
            },
            "_source": [
                "filename",
                "mimetype",
                "page",
                "text",
                "source_url",
                "owner",
                "owner_name",
                "owner_email",
                "file_size",
                "connector_type",
                "allowed_users",
                "allowed_groups",
            ],
            "size": limit,
        }

        # A score threshold is only meaningful for the hybrid query. The extra
        # truthiness test keeps a missing (None) threshold from raising.
        if not is_wildcard_match_all and score_threshold and score_threshold > 0:
            search_body["min_score"] = score_threshold

        # Per-user client obtained through the session manager; the original
        # code notes that document-level security filters results server-side.
        opensearch_client = self.session_manager.get_user_opensearch_client(
            user_id, jwt_token
        )

        try:
            results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
        except Exception as e:
            logger.error(
                "OpenSearch query failed", error=str(e), search_body=search_body
            )
            # Re-raise so the API layer can report the failure to the frontend.
            raise

        hits_section = results.get("hits", {})
        chunks = [self._hit_to_chunk(hit) for hit in hits_section.get("hits", [])]

        # "total" is either a dict carrying a "value" key or a plain value.
        total = hits_section.get("total")
        if isinstance(total, dict):
            total = total.get("value")

        # Return both the transformed results and the aggregations.
        return {
            "results": chunks,
            "aggregations": results.get("aggregations", {}),
            "total": total,
        }

    async def search(
        self,
        query: str,
        user_id: str = None,
        jwt_token: str = None,
        filters: Dict[str, Any] = None,
        limit: int = 10,
        score_threshold: float = 0,
    ) -> Dict[str, Any]:
        """Public search method for API endpoints.

        Stores the given identity and search options in ``auth_context`` and
        then runs ``search_tool``.

        Args:
            query: Query string; ``"*"`` requests a match-all search.
            user_id: Caller identity. Written to the context only together with
                a JWT token, or when no-auth mode is active.
            jwt_token: Token stored alongside ``user_id``.
            filters: Frontend filter dict; written to the context only when
                non-empty.
            limit: Maximum number of hits to return.
            score_threshold: Minimum score for hybrid results; values <= 0
                disable the threshold.

        Returns:
            The dict produced by ``search_tool``.
        """
        from auth_context import (
            set_auth_context,
            set_search_filters,
            set_search_limit,
            set_score_threshold,
        )
        from config.settings import is_no_auth_mode

        # Set auth context if provided (for direct API calls).
        if user_id and (jwt_token or is_no_auth_mode()):
            set_auth_context(user_id, jwt_token)

        # NOTE(review): falsy filters are not written, so filters already
        # present in the context stay in effect — confirm this is intended.
        if filters:
            set_search_filters(filters)

        set_search_limit(limit)
        set_score_threshold(score_threshold)

        return await self.search_tool(query)