"""Search tool for OpenRAG MCP server."""
|
|
|
|
import logging
|
|
|
|
from mcp.types import TextContent, Tool
|
|
|
|
from openrag_sdk import (
|
|
AuthenticationError,
|
|
OpenRAGError,
|
|
RateLimitError,
|
|
SearchFilters,
|
|
ServerError,
|
|
ValidationError,
|
|
)
|
|
|
|
from openrag_mcp.config import get_openrag_client
|
|
from openrag_mcp.tools.registry import register_tool
|
|
|
|
# Module-level logger. It uses the fixed name "openrag-mcp.search" rather than
# __name__, so that literal name is what logging configuration has to target.
logger = logging.getLogger("openrag-mcp.search")
|
|
|
|
|
|
# Tool definition.
#
# The JSON Schema describing the accepted arguments is assembled first and then
# wrapped in the MCP ``Tool`` descriptor. Only "query" is mandatory; every
# other property is optional.
_SEARCH_INPUT_SCHEMA = {
    "type": "object",
    "properties": {
        "query": dict(
            type="string",
            description="The search query",
        ),
        "limit": dict(
            type="integer",
            description="Maximum number of results (default: 10)",
            default=10,
        ),
        "score_threshold": dict(
            type="number",
            description="Minimum relevance score threshold (default: 0)",
            default=0,
        ),
        "filter_id": dict(
            type="string",
            description="Optional knowledge filter ID to apply",
        ),
        "data_sources": dict(
            type="array",
            items={"type": "string"},
            description="Optional list of filenames to filter by",
        ),
        "document_types": dict(
            type="array",
            items={"type": "string"},
            description="Optional list of MIME types to filter by (e.g., 'application/pdf')",
        ),
    },
    "required": ["query"],
}

SEARCH_TOOL = Tool(
    name="openrag_search",
    description=" ".join(
        [
            "Search the OpenRAG knowledge base using semantic search.",
            "Returns matching document chunks with relevance scores.",
            "Optionally filter by data sources or document types.",
        ]
    ),
    inputSchema=_SEARCH_INPUT_SCHEMA,
)
|
|
|
|
|
|
# Fallbacks used when the caller omits an argument or sends an explicit null.
_DEFAULT_LIMIT = 10
_DEFAULT_SCORE_THRESHOLD = 0
# Longest snippet (in characters) shown per hit before it is cut with "...".
_MAX_SNIPPET_CHARS = 500


def _text_reply(message: str) -> list[TextContent]:
    """Wrap *message* in the single-element text payload every branch returns."""
    return [TextContent(type="text", text=message)]


def _format_results(results) -> str:
    """Render search hits as the Markdown-flavoured text sent back to the client."""
    parts = [f"Found {len(results)} result(s):\n"]

    for index, hit in enumerate(results, 1):
        parts.append(f"\n---\n**{index}. {hit.filename}**")
        if hit.page:
            parts.append(f" (page {hit.page})")
        parts.append(f"\nRelevance: {hit.score:.2f}\n")

        # Defensive: a hit with no text would make len() raise and discard
        # every other hit, so treat a missing body as an empty snippet.
        snippet = hit.text or ""
        if len(snippet) > _MAX_SNIPPET_CHARS:
            snippet = snippet[:_MAX_SNIPPET_CHARS] + "..."
        parts.append(f"\n{snippet}\n")

    return "".join(parts)


async def handle_search(arguments: dict) -> list[TextContent]:
    """Handle openrag_search tool calls.

    Args:
        arguments: Tool-call arguments. Recognised keys are ``query``
            (required), ``limit``, ``score_threshold``, ``filter_id``,
            ``data_sources`` and ``document_types``. A ``limit`` or
            ``score_threshold`` that is missing or ``None`` falls back to
            10 and 0 respectively.

    Returns:
        A one-element list containing a text block: the formatted hits,
        ``"No results found."``, or an error description. Failures are
        reported as text; this function does not propagate exceptions.
    """
    query = arguments.get("query")
    # Reject blank and non-string queries up front instead of sending a
    # request that cannot match anything meaningful.
    if not isinstance(query, str) or not query.strip():
        return _text_reply("Error: query is required")

    # dict.get(key, default) only covers a *missing* key; a client that sends
    # an explicit null would otherwise pass None through to the SDK.
    limit = arguments.get("limit")
    if limit is None:
        limit = _DEFAULT_LIMIT
    score_threshold = arguments.get("score_threshold")
    if score_threshold is None:
        score_threshold = _DEFAULT_SCORE_THRESHOLD

    filter_id = arguments.get("filter_id")
    data_sources = arguments.get("data_sources")
    document_types = arguments.get("document_types")

    try:
        client = get_openrag_client()

        # Only build a filter object when at least one filter list is non-empty.
        filters = None
        if data_sources or document_types:
            filters = SearchFilters(
                data_sources=data_sources,
                document_types=document_types,
            )

        response = await client.search.query(
            query=query,
            limit=limit,
            score_threshold=score_threshold,
            filter_id=filter_id,
            filters=filters,
        )

        if not response.results:
            return _text_reply("No results found.")

        return _text_reply(_format_results(response.results))

    # The specific SDK errors are listed before the OpenRAGError catch-all so
    # each gets its own user-facing prefix.
    except AuthenticationError as e:
        logger.error("Authentication error: %s", e.message)
        return _text_reply(f"Authentication error: {e.message}")
    except ValidationError as e:
        logger.error("Validation error: %s", e.message)
        return _text_reply(f"Invalid request: {e.message}")
    except RateLimitError as e:
        logger.error("Rate limit error: %s", e.message)
        return _text_reply(f"Rate limited: {e.message}")
    except ServerError as e:
        logger.error("Server error: %s", e.message)
        return _text_reply(f"Server error: {e.message}")
    except OpenRAGError as e:
        logger.error("OpenRAG error: %s", e.message)
        return _text_reply(f"Error: {e.message}")
    except Exception as e:
        # Boundary handler: anything unexpected is still reported as text, but
        # the traceback is logged so the failure can be diagnosed.
        logger.exception("Search error: %s", e)
        return _text_reply(f"Error: {e}")
|
|
|
|
|
|
# Register the tool
|
|
# This call is a top-level statement, so it runs when the module is imported:
# importing this module is what pairs SEARCH_TOOL with handle_search.
register_tool(SEARCH_TOOL, handle_search)
|