171 lines
No EOL
7 KiB
Python
171 lines
No EOL
7 KiB
Python
from typing import Any, Dict, Optional
|
|
from agentd.tool_decorator import tool
|
|
from config.settings import clients, INDEX_NAME, EMBED_MODEL
|
|
from auth_context import get_auth_context
|
|
|
|
class SearchService:
    """Hybrid (semantic + keyword) document search against an OpenSearch index.

    ``search_tool`` takes only the query string. Auth credentials, filters,
    result limit and score threshold are read from the ``auth_context``
    module; ``search`` is a convenience entry point that stores those values
    there before delegating to ``search_tool``.
    """

    # Map frontend filter names to backend field names. Keys not listed here
    # are passed through to OpenSearch unchanged.
    _FIELD_MAPPING = {
        "data_sources": "filename",
        "document_types": "mimetype",
        "owners": "owner",
    }

    # Lower bound on the number of nearest neighbours requested from the kNN
    # clause (the value previously hard-coded).
    _MIN_KNN_K = 10

    def __init__(self, session_manager=None):
        self.session_manager = session_manager

    # NOTE(review): the docstring below is kept verbatim because @tool may use
    # it as the tool description/schema - confirm before editing it.
    @tool
    async def search_tool(self, query: str) -> Dict[str, Any]:
        """
        Use this tool to search for documents relevant to the query.

        Args:
            query (str): query string to search the corpus

        Returns:
            dict (str, Any): {"results": [chunks]} on success
        """
        from auth_context import get_search_filters, get_search_limit, get_score_threshold

        # Get authentication context from the current async context.
        user_id, jwt_token = get_auth_context()

        # Authentication required - DLS will handle document filtering
        # automatically. Checked before any other work so unauthenticated
        # calls never reach the embedding API.
        if not user_id:
            return {"results": [], "error": "Authentication required"}

        # Get search filters, limit, and score threshold from context.
        filters = get_search_filters() or {}
        limit = get_search_limit()
        score_threshold = get_score_threshold()

        # A bare "*" returns global facets/stats without semantic search.
        is_wildcard_match_all = isinstance(query, str) and query.strip() == "*"

        filter_clauses = self._build_filter_clauses(filters)

        if is_wildcard_match_all:
            query_block = self._match_all_query(filter_clauses)
        else:
            # Only embed when doing a real (non match_all) search.
            resp = await clients.patched_async_client.embeddings.create(model=EMBED_MODEL, input=[query])
            query_embedding = resp.data[0].embedding
            # Ask the kNN clause for at least as many neighbours as results
            # requested; a fixed k silently caps semantic recall when
            # limit > k.
            knn_k = max(self._MIN_KNN_K, limit or 0)
            query_block = self._hybrid_query(query, query_embedding, filter_clauses, knn_k)

        search_body = {
            "query": query_block,
            "aggs": {
                "data_sources": {"terms": {"field": "filename", "size": 20}},
                "document_types": {"terms": {"field": "mimetype", "size": 10}},
                "owners": {"terms": {"field": "owner", "size": 10}},
            },
            "_source": ["filename", "mimetype", "page", "text", "source_url", "owner", "allowed_users", "allowed_groups"],
            "size": limit,
        }

        # Score threshold only applies to hybrid search (match_all scores
        # are not meaningful); a missing threshold is treated as 0.
        if not is_wildcard_match_all and (score_threshold or 0) > 0:
            search_body["min_score"] = score_threshold

        # Get user's OpenSearch client with JWT for OIDC auth.
        opensearch_client = clients.create_user_opensearch_client(jwt_token)
        results = await opensearch_client.search(index=INDEX_NAME, body=search_body)

        hits = results.get("hits", {})
        # Transformed results are kept for backward compatibility.
        chunks = [self._hit_to_chunk(hit) for hit in hits.get("hits", [])]

        # Return both transformed results and aggregations.
        return {
            "results": chunks,
            "aggregations": results.get("aggregations", {}),
            "total": self._extract_total(hits),
        }

    async def search(
        self,
        query: str,
        user_id: Optional[str] = None,
        jwt_token: Optional[str] = None,
        filters: Optional[Dict[str, Any]] = None,
        limit: int = 10,
        score_threshold: float = 0,
    ) -> Dict[str, Any]:
        """Public search method for API endpoints.

        Stores the given parameters in ``auth_context`` and delegates to
        ``search_tool``.

        Args:
            query: Query string; ``"*"`` requests facets/stats over all
                (filtered) documents.
            user_id: Caller's user id. The auth context is only set when both
                ``user_id`` and ``jwt_token`` are given; otherwise whatever is
                already in the context is used.
            jwt_token: JWT used to create the per-user OpenSearch client.
            filters: Mapping of filter name to list of allowed values. ``None``
                leaves the filters already in the context untouched; any dict
                (including ``{}``) replaces them.
            limit: Maximum number of hits to return.
            score_threshold: Minimum score for hybrid search; ``0`` disables it.

        Returns:
            The dict produced by ``search_tool``.
        """
        from auth_context import set_auth_context, set_search_filters, set_search_limit, set_score_threshold

        # Set auth context if provided (for direct API calls).
        if user_id and jwt_token:
            set_auth_context(user_id, jwt_token)

        # An explicit empty dict must clear earlier filters rather than be
        # ignored, otherwise stale filters from a previous call would apply.
        if filters is not None:
            set_search_filters(filters)

        set_search_limit(limit)
        set_score_threshold(score_threshold)

        return await self.search_tool(query)

    @classmethod
    def _build_filter_clauses(cls, filters: Dict[str, Any]) -> list:
        """Translate frontend filters into OpenSearch ``bool.filter`` clauses.

        Only list-valued filters are used; other values are ignored. An empty
        list means "match nothing", one value becomes a ``term`` clause, and
        several values become a ``terms`` clause.
        """
        clauses = []
        for filter_key, values in (filters or {}).items():
            if not isinstance(values, list):
                continue
            field_name = cls._FIELD_MAPPING.get(filter_key, filter_key)
            if not values:
                # Empty array means "match nothing".
                clauses.append({"match_none": {}})
            elif len(values) == 1:
                clauses.append({"term": {field_name: values[0]}})
            else:
                clauses.append({"terms": {field_name: values}})
        return clauses

    @staticmethod
    def _match_all_query(filter_clauses: list) -> Dict[str, Any]:
        """Match every document, narrowed by ``filter_clauses`` when present."""
        if filter_clauses:
            return {"bool": {"filter": filter_clauses}}
        return {"match_all": {}}

    @staticmethod
    def _hybrid_query(query: str, query_embedding: Any, filter_clauses: list, knn_k: int) -> Dict[str, Any]:
        """Build the semantic (kNN) + keyword (multi_match) bool query.

        At least one of the two ``should`` clauses must match; the kNN clause
        is boosted 0.7 and the keyword clause 0.3.
        """
        bool_query: Dict[str, Any] = {
            "should": [
                {
                    "knn": {
                        "chunk_embedding": {
                            "vector": query_embedding,
                            "k": knn_k,
                            "boost": 0.7,
                        }
                    }
                },
                {
                    "multi_match": {
                        "query": query,
                        "fields": ["text^2", "filename^1.5"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                        "boost": 0.3,
                    }
                },
            ],
            "minimum_should_match": 1,
        }
        if filter_clauses:
            bool_query["filter"] = filter_clauses
        return {"bool": bool_query}

    @staticmethod
    def _hit_to_chunk(hit: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten one OpenSearch hit into the legacy chunk dict.

        ``filename``, ``mimetype``, ``page`` and ``text`` are required in
        ``_source`` (a missing one raises ``KeyError``); ``source_url`` and
        ``owner`` default to ``None``.
        """
        source = hit["_source"]
        return {
            "filename": source["filename"],
            "mimetype": source["mimetype"],
            "page": source["page"],
            "text": source["text"],
            "score": hit["_score"],
            "source_url": source.get("source_url"),
            "owner": source.get("owner"),
        }

    @staticmethod
    def _extract_total(hits: Dict[str, Any]) -> Any:
        """Return the hit count from ``hits.total``.

        Handles both the object form ``{"value": N, ...}`` and a bare number;
        returns ``None`` when absent.
        """
        total = hits.get("total")
        if isinstance(total, dict):
            return total.get("value")
        return total