# openrag/sdks/mcp/src/openrag_mcp/tools/search.py
# 2025-12-26 14:21:03 -05:00
#
# 139 lines
# 4.5 KiB
# Python

"""Search tool for OpenRAG MCP server."""
import logging
from mcp.types import TextContent, Tool
from openrag_sdk import (
AuthenticationError,
OpenRAGError,
RateLimitError,
SearchFilters,
ServerError,
ValidationError,
)
from openrag_mcp.config import get_openrag_client
from openrag_mcp.tools.registry import register_tool
# Module-level logger; its dotted name makes it a child of the "openrag-mcp" logger.
logger = logging.getLogger("openrag-mcp.search")
# JSON Schema describing the arguments accepted by the ``openrag_search`` tool.
# Only ``query`` is listed as required; every other property is optional.
_SEARCH_INPUT_SCHEMA = {
    "type": "object",
    "properties": {
        "query": {
            "type": "string",
            "description": "The search query",
        },
        "limit": {
            "type": "integer",
            "description": "Maximum number of results (default: 10)",
            "default": 10,
        },
        "score_threshold": {
            "type": "number",
            "description": "Minimum relevance score threshold (default: 0)",
            "default": 0,
        },
        "filter_id": {
            "type": "string",
            "description": "Optional knowledge filter ID to apply",
        },
        "data_sources": {
            "type": "array",
            "items": {"type": "string"},
            "description": "Optional list of filenames to filter by",
        },
        "document_types": {
            "type": "array",
            "items": {"type": "string"},
            "description": "Optional list of MIME types to filter by (e.g., 'application/pdf')",
        },
    },
    "required": ["query"],
}

# Tool definition for the search tool.
SEARCH_TOOL = Tool(
    name="openrag_search",
    description=(
        "Search the OpenRAG knowledge base using semantic search. "
        "Returns matching document chunks with relevance scores. "
        "Optionally filter by data sources or document types."
    ),
    inputSchema=_SEARCH_INPUT_SCHEMA,
)
# Longest excerpt of a result's text included in the reply before truncation.
_MAX_SNIPPET_CHARS = 500


def _text_reply(text: str) -> list[TextContent]:
    """Wrap *text* in the single-item ``TextContent`` list the handler returns."""
    return [TextContent(type="text", text=text)]


def _format_results(results) -> str:
    """Render search results as the text block returned to the caller."""
    parts = [f"Found {len(results)} result(s):\n"]
    for index, result in enumerate(results, 1):
        parts.append(f"\n---\n**{index}. {result.filename}**")
        if result.page:
            parts.append(f" (page {result.page})")
        parts.append(f"\nRelevance: {result.score:.2f}\n")
        # A result without text must not abort formatting of the other results.
        snippet = result.text or ""
        if len(snippet) > _MAX_SNIPPET_CHARS:
            snippet = snippet[:_MAX_SNIPPET_CHARS] + "..."
        parts.append(f"\n{snippet}\n")
    return "".join(parts)


async def handle_search(arguments: dict) -> list[TextContent]:
    """Handle openrag_search tool calls.

    Args:
        arguments: Tool-call arguments. Recognised keys are ``query``
            (required, non-blank string), ``limit`` (default 10),
            ``score_threshold`` (default 0), ``filter_id``, ``data_sources``
            and ``document_types``.

    Returns:
        A one-element list holding a ``TextContent`` with either the
        formatted results, ``"No results found."``, or an error message.
        Errors raised while searching are logged and reported as text
        rather than propagated.
    """
    arguments = arguments or {}

    # Reject missing, non-string and whitespace-only queries up front instead
    # of sending a meaningless request to the backend.
    query = arguments.get("query")
    if not isinstance(query, str) or not query.strip():
        return _text_reply("Error: query is required")

    # An explicit null for an optional field is treated the same as omitting
    # it; ``dict.get`` defaults only apply when the key is absent.
    limit = arguments.get("limit")
    if limit is None:
        limit = 10
    score_threshold = arguments.get("score_threshold")
    if score_threshold is None:
        score_threshold = 0

    filter_id = arguments.get("filter_id")
    data_sources = arguments.get("data_sources")
    document_types = arguments.get("document_types")

    try:
        client = get_openrag_client()

        # Build filters only when at least one filter list was supplied.
        filters = None
        if data_sources or document_types:
            filters = SearchFilters(
                data_sources=data_sources,
                document_types=document_types,
            )

        response = await client.search.query(
            query=query,
            limit=limit,
            score_threshold=score_threshold,
            filter_id=filter_id,
            filters=filters,
        )

        if not response.results:
            return _text_reply("No results found.")

        return _text_reply(_format_results(response.results))
    except AuthenticationError as e:
        logger.error("Authentication error: %s", e.message)
        return _text_reply(f"Authentication error: {e.message}")
    except ValidationError as e:
        logger.error("Validation error: %s", e.message)
        return _text_reply(f"Invalid request: {e.message}")
    except RateLimitError as e:
        logger.error("Rate limit error: %s", e.message)
        return _text_reply(f"Rate limited: {e.message}")
    except ServerError as e:
        logger.error("Server error: %s", e.message)
        return _text_reply(f"Server error: {e.message}")
    except OpenRAGError as e:
        logger.error("OpenRAG error: %s", e.message)
        return _text_reply(f"Error: {e.message}")
    except Exception as e:
        # Unexpected failure: keep the traceback in the log for diagnosis.
        logger.exception("Search error: %s", e)
        return _text_reply(f"Error: {e}")
# Register the search tool definition together with its handler. The call sits
# at module level, so it runs when this module is imported.
register_tool(SEARCH_TOOL, handle_search)