This commit updates the SearchService class to read fields from each search hit with the dictionary `get` method instead of direct indexing. A hit that lacks an optional field now yields `None` for that field rather than raising a KeyError, which makes result handling more robust and keeps the async search path easier to maintain. It also simplifies how data is extracted from the search results, improving overall code clarity.
223 lines
8.3 KiB
Python
223 lines
8.3 KiB
Python
from typing import Any, Dict
|
|
from agentd.tool_decorator import tool
|
|
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
|
from auth_context import get_auth_context
|
|
from utils.logging_config import get_logger
|
|
|
|
# Module-level logger, created through the project's logging helper and named
# after this module so log records can be traced back to the search service.
logger = get_logger(__name__)
|
|
|
|
|
|
class SearchService:
    """Hybrid (semantic + keyword) document search against the OpenSearch index.

    ``search_tool`` is the ``@tool``-decorated entry point; it reads the caller's
    identity, filters, result limit and score threshold from ``auth_context``.
    ``search`` is the wrapper for API endpoints: it writes those values into the
    context and then delegates to ``search_tool``.
    """

    # Frontend filter names -> index field names. Keys that are not listed here
    # are used as field names unchanged.
    _FILTER_FIELD_MAP = {
        "data_sources": "filename",
        "document_types": "mimetype",
        "owners": "owner_name.keyword",
        "connector_types": "connector_type",
    }

    # Value used to build a filter clause that is meant to match no document.
    _MATCH_NOTHING_VALUE = "__IMPOSSIBLE_VALUE__"

    # Fields requested from OpenSearch for every hit.
    _SOURCE_FIELDS = (
        "filename",
        "mimetype",
        "page",
        "text",
        "source_url",
        "owner",
        "owner_name",
        "owner_email",
        "file_size",
        "connector_type",
        "allowed_users",
        "allowed_groups",
    )

    def __init__(self, session_manager=None):
        """Keep the session manager used to obtain per-user OpenSearch clients.

        Args:
            session_manager: object exposing ``get_user_opensearch_client``.
                It is required by ``search_tool`` once a request is
                authenticated; the ``None`` default only allows construction.
        """
        self.session_manager = session_manager

    @classmethod
    def _build_filter_clauses(cls, filters):
        """Translate frontend filters into OpenSearch ``term``/``terms`` clauses.

        Only list values are considered; ``None`` and non-list values are skipped.
        """
        clauses = []
        for filter_key, values in (filters or {}).items():
            if not isinstance(values, list):
                continue

            field_name = cls._FILTER_FIELD_MAP.get(filter_key, filter_key)
            if not values:
                # Empty list means "match nothing" - use an impossible filter.
                clauses.append({"term": {field_name: cls._MATCH_NOTHING_VALUE}})
            elif len(values) == 1:
                clauses.append({"term": {field_name: values[0]}})
            else:
                clauses.append({"terms": {field_name: values}})
        return clauses

    @staticmethod
    def _match_all_query(filter_clauses):
        """Build the wildcard query: every document, narrowed by any filters."""
        if filter_clauses:
            return {"bool": {"filter": filter_clauses}}
        return {"match_all": {}}

    @staticmethod
    def _hybrid_query(query, query_embedding, filter_clauses):
        """Build the hybrid query: k-NN on the embedding plus fuzzy keyword match."""
        bool_query = {
            "should": [
                {
                    "knn": {
                        "chunk_embedding": {
                            "vector": query_embedding,
                            "k": 10,
                            "boost": 0.7,
                        }
                    }
                },
                {
                    "multi_match": {
                        "query": query,
                        "fields": ["text^2", "filename^1.5"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                        "boost": 0.3,
                    }
                },
            ],
            "minimum_should_match": 1,
        }
        if filter_clauses:
            bool_query["filter"] = filter_clauses
        return {"bool": bool_query}

    @classmethod
    def _build_search_body(cls, query_block, limit):
        """Assemble the request body: query, facet aggregations, fields and size."""
        return {
            "query": query_block,
            "aggs": {
                "data_sources": {"terms": {"field": "filename", "size": 20}},
                "document_types": {"terms": {"field": "mimetype", "size": 10}},
                "owners": {"terms": {"field": "owner_name.keyword", "size": 10}},
                "connector_types": {"terms": {"field": "connector_type", "size": 10}},
            },
            "_source": list(cls._SOURCE_FIELDS),
            "size": limit,
        }

    @staticmethod
    def _hit_to_chunk(hit):
        """Flatten one OpenSearch hit into the chunk dict returned to callers.

        Missing fields (or a missing ``_source``) yield ``None`` values instead
        of raising ``KeyError``.
        """
        source = hit.get("_source") or {}
        return {
            "filename": source.get("filename"),
            "mimetype": source.get("mimetype"),
            "page": source.get("page"),
            "text": source.get("text"),
            "score": hit.get("_score"),
            "source_url": source.get("source_url"),
            "owner": source.get("owner"),
            "owner_name": source.get("owner_name"),
            "owner_email": source.get("owner_email"),
            "file_size": source.get("file_size"),
            "connector_type": source.get("connector_type"),
        }

    @tool
    async def search_tool(self, query: str) -> Dict[str, Any]:
        """
        Use this tool to search for documents relevant to the query.

        Args:
            query (str): query string to search the corpus

        Returns:
            dict (str, Any): {"results": [chunks]} on success
        """
        # NOTE(review): the docstring above is kept verbatim; presumably the
        # @tool decorator surfaces it as the tool description - confirm before
        # rewording it.
        #
        # Maintainer notes:
        # - On success the result also carries "aggregations" and "total".
        # - Without a user_id in the auth context the method returns
        #   {"results": [], "error": "Authentication required"}.
        # - Exceptions raised by the OpenSearch call are logged and re-raised.
        from auth_context import (
            get_search_filters,
            get_search_limit,
            get_score_threshold,
        )

        # Check authentication first so unauthenticated calls never trigger
        # the (remote) embedding request or any query construction.
        user_id, jwt_token = get_auth_context()
        logger.debug(
            "search_service authentication info",
            user_id=user_id,
            has_jwt_token=jwt_token is not None,
        )
        if not user_id:
            logger.debug("search_service: user_id is None/empty, returning auth error")
            return {"results": [], "error": "Authentication required"}

        # Search filters, limit, and score threshold come from the async context.
        filters = get_search_filters() or {}
        limit = get_search_limit()
        score_threshold = get_score_threshold()

        # A bare "*" requests global facets/stats without semantic search.
        is_wildcard_match_all = isinstance(query, str) and query.strip() == "*"

        filter_clauses = self._build_filter_clauses(filters)

        if is_wildcard_match_all:
            query_block = self._match_all_query(filter_clauses)
        else:
            # Only embed when not doing match_all.
            resp = await clients.patched_async_client.embeddings.create(
                model=EMBED_MODEL, input=[query]
            )
            query_embedding = resp.data[0].embedding
            query_block = self._hybrid_query(query, query_embedding, filter_clauses)

        search_body = self._build_search_body(query_block, limit)

        # Score threshold only applies to hybrid search (not meaningful for
        # match_all). A missing (None) threshold is treated as "no threshold".
        if not is_wildcard_match_all and score_threshold and score_threshold > 0:
            search_body["min_score"] = score_threshold

        # Per-user OpenSearch client (JWT for OIDC auth) via the session manager.
        # Per the original note, document-level security (DLS) is expected to
        # restrict the visible documents; that is not verified in this method.
        opensearch_client = self.session_manager.get_user_opensearch_client(
            user_id, jwt_token
        )

        try:
            results = await opensearch_client.search(index=INDEX_NAME, body=search_body)
        except Exception as e:
            logger.error(
                "OpenSearch query failed", error=str(e), search_body=search_body
            )
            # Re-raise so the API returns the error to the frontend.
            raise

        hits_section = results.get("hits") or {}
        chunks = [self._hit_to_chunk(hit) for hit in hits_section.get("hits") or []]

        # "total" is either a plain number or a {"value": ...} object.
        total = hits_section.get("total")
        if isinstance(total, dict):
            total = total.get("value")

        # Return both transformed results and aggregations.
        return {
            "results": chunks,
            "aggregations": results.get("aggregations", {}),
            "total": total,
        }

    async def search(
        self,
        query: str,
        user_id: str = None,
        jwt_token: str = None,
        filters: Dict[str, Any] = None,
        limit: int = 10,
        score_threshold: float = 0,
    ) -> Dict[str, Any]:
        """Public search method for API endpoints.

        Writes the supplied identity and search options into ``auth_context``
        and delegates to ``search_tool``.

        Args:
            query: query string; ``"*"`` requests a match-all search.
            user_id: caller identity; stored in the context only together with
                a JWT, or when no-auth mode is enabled.
            jwt_token: token stored alongside ``user_id``.
            filters: mapping of filter name to list of allowed values.
            limit: maximum number of hits to request.
            score_threshold: minimum score for hybrid search; values ``<= 0``
                disable the threshold.

        Returns:
            The dictionary produced by ``search_tool``.
        """
        from config.settings import is_no_auth_mode
        from auth_context import set_search_limit, set_score_threshold

        # Set auth context if provided (for direct API calls).
        if user_id and (jwt_token or is_no_auth_mode()):
            from auth_context import set_auth_context

            set_auth_context(user_id, jwt_token)

        # NOTE(review): falsy filters (None or {}) are not written to the
        # context, so filters set earlier in the same context may still apply -
        # confirm this is intended.
        if filters:
            from auth_context import set_search_filters

            set_search_filters(filters)

        set_search_limit(limit)
        set_score_threshold(score_threshold)

        return await self.search_tool(query)
|