from typing import Any, Dict, List, Optional

from agentd.tool_decorator import tool
from config.settings import clients, INDEX_NAME, EMBED_MODEL
from auth_context import get_auth_context

# Frontend filter names -> backend index field names. Keys that are not
# listed here are used as the field name unchanged.
_FILTER_FIELD_MAP = {
    "data_sources": "filename",
    "document_types": "mimetype",
    "owners": "owner_name.keyword",
    "connector_types": "connector_type",
}

# Fields requested from the index for every hit.
_SOURCE_FIELDS = [
    "filename",
    "mimetype",
    "page",
    "text",
    "source_url",
    "owner",
    "owner_name",
    "owner_email",
    "file_size",
    "connector_type",
    "allowed_users",
    "allowed_groups",
]

# Hit fields that may be absent from a document; they are read with .get().
_OPTIONAL_HIT_FIELDS = (
    "source_url",
    "owner",
    "owner_name",
    "owner_email",
    "file_size",
    "connector_type",
)


class SearchService:
    """Run hybrid (kNN + keyword) or match-all searches against INDEX_NAME.

    The per-request user, filters, result limit and score threshold are read
    from ``auth_context``; the OpenSearch client is obtained per user from
    the injected session manager.
    """

    def __init__(self, session_manager=None):
        # Must provide ``get_user_opensearch_client(user_id, jwt_token)``.
        self.session_manager = session_manager

    @staticmethod
    def _build_filter_clauses(filters: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Translate frontend filters into OpenSearch term/terms clauses."""
        clauses: List[Dict[str, Any]] = []
        for filter_key, values in filters.items():
            # Only list-valued filters are honoured; anything else
            # (including None) is ignored.
            if not isinstance(values, list):
                continue
            field_name = _FILTER_FIELD_MAP.get(filter_key, filter_key)
            if not values:
                # Empty list means "match nothing" - use an impossible value.
                clauses.append({"term": {field_name: "__IMPOSSIBLE_VALUE__"}})
            elif len(values) == 1:
                clauses.append({"term": {field_name: values[0]}})
            else:
                clauses.append({"terms": {field_name: values}})
        return clauses

    @staticmethod
    def _build_query_block(
        query: str,
        query_embedding: Optional[Any],
        filter_clauses: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Build the ``query`` section of the search body.

        ``query_embedding`` is None for the wildcard ("*") case, which
        matches all documents (optionally narrowed by ``filter_clauses``).
        """
        if query_embedding is None:
            if filter_clauses:
                return {"bool": {"filter": filter_clauses}}
            return {"match_all": {}}

        # Hybrid search query structure (semantic + keyword).
        bool_query: Dict[str, Any] = {
            "should": [
                {
                    "knn": {
                        "chunk_embedding": {
                            "vector": query_embedding,
                            # NOTE(review): k is fixed at 10 while "size"
                            # follows the caller's limit - confirm intended.
                            "k": 10,
                            "boost": 0.7,
                        }
                    }
                },
                {
                    "multi_match": {
                        "query": query,
                        "fields": ["text^2", "filename^1.5"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                        "boost": 0.3,
                    }
                },
            ],
            "minimum_should_match": 1,
        }
        if filter_clauses:
            bool_query["filter"] = filter_clauses
        return {"bool": bool_query}

    @staticmethod
    def _format_hit(hit: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten one OpenSearch hit into the chunk dict returned to callers."""
        source = hit["_source"]
        chunk = {
            # These four keys are required; a missing one raises KeyError.
            "filename": source["filename"],
            "mimetype": source["mimetype"],
            "page": source["page"],
            "text": source["text"],
            "score": hit["_score"],
        }
        for field in _OPTIONAL_HIT_FIELDS:
            chunk[field] = source.get(field)
        return chunk

    # NOTE: the docstring below is kept verbatim; presumably the @tool
    # decorator exposes it as the tool description - verify before editing.
    @tool
    async def search_tool(self, query: str) -> Dict[str, Any]:
        """
        Use this tool to search for documents relevant to the query.

        Args:
            query (str): query string to search the corpus

        Returns:
            dict (str, Any): {"results": [chunks]} on success
        """
        # Get authentication context from the current async context
        user_id, jwt_token = get_auth_context()

        # Authentication required - reject before doing any paid/remote work
        # (embedding request, index query) for a caller we cannot identify.
        token_state = "None" if jwt_token is None else "present"
        print(f"[DEBUG] search_service: user_id={user_id}, jwt_token={token_state}")
        if not user_id:
            print(
                "[DEBUG] search_service: user_id is None/empty, returning auth error"
            )
            return {"results": [], "error": "Authentication required"}

        # Get search filters, limit, and score threshold from context
        from auth_context import (
            get_search_filters,
            get_search_limit,
            get_score_threshold,
        )

        filters = get_search_filters() or {}
        limit = get_search_limit()
        score_threshold = get_score_threshold()

        # Wildcard request ("*") returns global facets/stats without
        # semantic search, so no embedding is needed for it.
        is_wildcard_match_all = isinstance(query, str) and query.strip() == "*"

        query_embedding = None
        if not is_wildcard_match_all:
            resp = await clients.patched_async_client.embeddings.create(
                model=EMBED_MODEL, input=[query]
            )
            query_embedding = resp.data[0].embedding

        filter_clauses = self._build_filter_clauses(filters)

        search_body = {
            "query": self._build_query_block(
                query, query_embedding, filter_clauses
            ),
            "aggs": {
                "data_sources": {"terms": {"field": "filename", "size": 20}},
                "document_types": {"terms": {"field": "mimetype", "size": 10}},
                "owners": {"terms": {"field": "owner_name.keyword", "size": 10}},
                "connector_types": {
                    "terms": {"field": "connector_type", "size": 10}
                },
            },
            "_source": list(_SOURCE_FIELDS),
            "size": limit,
        }

        # Add score threshold only for hybrid (not meaningful for match_all)
        if not is_wildcard_match_all and score_threshold > 0:
            search_body["min_score"] = score_threshold

        # Get user's OpenSearch client with JWT for OIDC auth through the
        # session manager; DLS is expected to filter documents per user.
        opensearch_client = self.session_manager.get_user_opensearch_client(
            user_id, jwt_token
        )

        try:
            results = await opensearch_client.search(
                index=INDEX_NAME, body=search_body
            )
        except Exception as e:
            print(f"[ERROR] OpenSearch query failed: {e}")
            print(f"[ERROR] Search body: {search_body}")
            # Re-raise the exception so the API returns the error to frontend
            raise

        # Transform results (keep for backward compatibility)
        chunks = [self._format_hit(hit) for hit in results["hits"]["hits"]]

        # "total" may be a {"value": n, ...} dict or a plain number.
        total = results.get("hits", {}).get("total")
        if isinstance(total, dict):
            total = total.get("value")

        # Return both transformed results and aggregations
        return {
            "results": chunks,
            "aggregations": results.get("aggregations", {}),
            "total": total,
        }

    async def search(
        self,
        query: str,
        user_id: Optional[str] = None,
        jwt_token: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,
        limit: int = 10,
        score_threshold: float = 0,
    ) -> Dict[str, Any]:
        """Public search method for API endpoints.

        Stores the supplied auth and search parameters in the request
        context, then delegates to ``search_tool``.

        Args:
            query: Query string; "*" requests a match-all search.
            user_id: Caller identity. Only stored when a ``jwt_token`` is
                also given or no-auth mode is enabled.
            jwt_token: Token passed on to the per-user OpenSearch client.
            filters: Mapping of filter name to list of accepted values.
                Only stored when truthy.
            limit: Maximum number of hits to return.
            score_threshold: Minimum score for hybrid search; values <= 0
                disable the threshold.

        Returns:
            The dict produced by ``search_tool``.
        """
        # Local imports kept at function scope, as in the original module.
        from config.settings import is_no_auth_mode
        from auth_context import (
            set_auth_context,
            set_search_filters,
            set_search_limit,
            set_score_threshold,
        )

        # Set auth context if provided (for direct API calls)
        if user_id and (jwt_token or is_no_auth_mode()):
            set_auth_context(user_id, jwt_token)

        # Filters are only written when provided; limit and threshold are
        # always written so the defaults apply.
        if filters:
            set_search_filters(filters)
        set_search_limit(limit)
        set_score_threshold(score_threshold)

        return await self.search_tool(query)