update to add tool registry
This commit is contained in:
parent
a6a473aff0
commit
91f90d9d9d
7 changed files with 292 additions and 258 deletions
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "openrag-mcp"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1"
|
||||
description = "MCP server for OpenRAG - Chat, Search, and Ingest documents"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""OpenRAG MCP Server - Main server setup and entry point."""
|
||||
#TODO: utilize the SDK directly so that any changes in parameters in SDK directky reflects the MCP
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
|
@ -8,9 +9,9 @@ from mcp.server.stdio import stdio_server
|
|||
from mcp.types import TextContent, Tool
|
||||
|
||||
from openrag_mcp.config import get_config
|
||||
from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool
|
||||
from openrag_mcp.tools.search import get_search_tools, handle_search_tool
|
||||
from openrag_mcp.tools.documents import get_document_tools, handle_document_tool
|
||||
|
||||
# Import tools module to trigger registration, then get registry functions
|
||||
from openrag_mcp.tools import get_all_tools, get_handler
|
||||
|
||||
# Configure logging to stderr (stdout is used for MCP protocol)
|
||||
logging.basicConfig(
|
||||
|
|
@ -30,34 +31,17 @@ def create_server() -> Server:
|
|||
# Create server instance
|
||||
server = Server("openrag-mcp")
|
||||
|
||||
# Register a single list_tools handler that combines all tools
|
||||
@server.list_tools()
|
||||
async def list_all_tools() -> list[Tool]:
|
||||
"""List all available tools."""
|
||||
tools = []
|
||||
tools.extend(get_chat_tools())
|
||||
tools.extend(get_search_tools())
|
||||
tools.extend(get_document_tools())
|
||||
return tools
|
||||
return get_all_tools()
|
||||
|
||||
# Register a single call_tool handler that dispatches to appropriate handler
|
||||
@server.call_tool()
|
||||
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
"""Handle all tool calls by dispatching to the appropriate handler."""
|
||||
# Try each handler in order
|
||||
result = await handle_chat_tool(name, arguments)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
result = await handle_search_tool(name, arguments)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
result = await handle_document_tool(name, arguments)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Unknown tool
|
||||
handler = get_handler(name)
|
||||
if handler:
|
||||
return await handler(arguments)
|
||||
return [TextContent(type="text", text=f"Unknown tool: {name}")]
|
||||
|
||||
logger.info("OpenRAG MCP server initialized with all tools")
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
"""OpenRAG MCP tools."""
|
||||
"""OpenRAG MCP tools.
|
||||
|
||||
from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool
|
||||
from openrag_mcp.tools.search import get_search_tools, handle_search_tool
|
||||
from openrag_mcp.tools.documents import get_document_tools, handle_document_tool
|
||||
Import this module to register all tools with the registry.
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"get_chat_tools",
|
||||
"handle_chat_tool",
|
||||
"get_search_tools",
|
||||
"handle_search_tool",
|
||||
"get_document_tools",
|
||||
"handle_document_tool",
|
||||
]
|
||||
# Import tools to trigger registration
|
||||
from openrag_mcp.tools import chat # noqa: F401
|
||||
from openrag_mcp.tools import search # noqa: F401
|
||||
from openrag_mcp.tools import documents # noqa: F401
|
||||
|
||||
# Re-export registry functions for convenience
|
||||
from openrag_mcp.tools.registry import get_all_tools, get_handler
|
||||
|
||||
__all__ = ["get_all_tools", "get_handler"]
|
||||
|
|
|
|||
|
|
@ -13,58 +13,53 @@ from openrag_sdk import (
|
|||
)
|
||||
|
||||
from openrag_mcp.config import get_openrag_client
|
||||
from openrag_mcp.tools.registry import register_tool
|
||||
|
||||
logger = logging.getLogger("openrag-mcp.chat")
|
||||
|
||||
|
||||
def get_chat_tools() -> list[Tool]:
|
||||
"""Return chat-related tools."""
|
||||
return [
|
||||
Tool(
|
||||
name="openrag_chat",
|
||||
description=(
|
||||
"Send a message to OpenRAG and get a RAG-enhanced response. "
|
||||
"The response is informed by documents in your knowledge base. "
|
||||
"Use chat_id to continue a previous conversation, or filter_id "
|
||||
"to apply a knowledge filter."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Your question or message to send to OpenRAG",
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Optional conversation ID to continue a previous chat",
|
||||
},
|
||||
"filter_id": {
|
||||
"type": "string",
|
||||
"description": "Optional knowledge filter ID to apply",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of sources to retrieve (default: 10)",
|
||||
"default": 10,
|
||||
},
|
||||
"score_threshold": {
|
||||
"type": "number",
|
||||
"description": "Minimum relevance score threshold (default: 0)",
|
||||
"default": 0,
|
||||
},
|
||||
},
|
||||
"required": ["message"],
|
||||
# Tool definition
|
||||
CHAT_TOOL = Tool(
|
||||
name="openrag_chat",
|
||||
description=(
|
||||
"Send a message to OpenRAG and get a RAG-enhanced response. "
|
||||
"The response is informed by documents in your knowledge base. "
|
||||
"Use chat_id to continue a previous conversation, or filter_id "
|
||||
"to apply a knowledge filter."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Your question or message to send to OpenRAG",
|
||||
},
|
||||
),
|
||||
]
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Optional conversation ID to continue a previous chat",
|
||||
},
|
||||
"filter_id": {
|
||||
"type": "string",
|
||||
"description": "Optional knowledge filter ID to apply",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of sources to retrieve (default: 10)",
|
||||
"default": 10,
|
||||
},
|
||||
"score_threshold": {
|
||||
"type": "number",
|
||||
"description": "Minimum relevance score threshold (default: 0)",
|
||||
"default": 0,
|
||||
},
|
||||
},
|
||||
"required": ["message"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def handle_chat_tool(name: str, arguments: dict) -> list[TextContent] | None:
|
||||
"""Handle chat tool calls. Returns None if tool not handled."""
|
||||
if name != "openrag_chat":
|
||||
return None
|
||||
|
||||
async def handle_chat(arguments: dict) -> list[TextContent]:
|
||||
"""Handle openrag_chat tool calls."""
|
||||
message = arguments.get("message", "")
|
||||
chat_id = arguments.get("chat_id")
|
||||
filter_id = arguments.get("filter_id")
|
||||
|
|
@ -115,3 +110,7 @@ async def handle_chat_tool(name: str, arguments: dict) -> list[TextContent] | No
|
|||
except Exception as e:
|
||||
logger.error(f"Chat error: {e}")
|
||||
return [TextContent(type="text", text=f"Error: {str(e)}")]
|
||||
|
||||
|
||||
# Register the tool
|
||||
register_tool(CHAT_TOOL, handle_chat)
|
||||
|
|
|
|||
|
|
@ -15,126 +15,42 @@ from openrag_sdk import (
|
|||
)
|
||||
|
||||
from openrag_mcp.config import get_openrag_client
|
||||
from openrag_mcp.tools.registry import register_tool
|
||||
|
||||
logger = logging.getLogger("openrag-mcp.documents")
|
||||
|
||||
|
||||
def get_document_tools() -> list[Tool]:
|
||||
"""Return document-related tools."""
|
||||
return [
|
||||
Tool(
|
||||
name="openrag_ingest_file",
|
||||
description=(
|
||||
"Ingest a local file into the OpenRAG knowledge base. "
|
||||
"Supported formats: PDF, DOCX, TXT, MD, HTML, and more. "
|
||||
"By default waits for ingestion to complete. Set wait=false to return immediately."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file to ingest",
|
||||
},
|
||||
"wait": {
|
||||
"type": "boolean",
|
||||
"description": "Wait for ingestion to complete (default: true). Set to false to return immediately with task_id.",
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["file_path"],
|
||||
# ============================================================================
|
||||
# Tool: openrag_ingest_file
|
||||
# ============================================================================
|
||||
|
||||
INGEST_FILE_TOOL = Tool(
|
||||
name="openrag_ingest_file",
|
||||
description=(
|
||||
"Ingest a local file into the OpenRAG knowledge base. "
|
||||
"Supported formats: PDF, DOCX, TXT, MD, HTML, and more. "
|
||||
"By default waits for ingestion to complete. Set wait=false to return immediately."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file to ingest",
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="openrag_ingest_url",
|
||||
description=(
|
||||
"Ingest content from a URL into the OpenRAG knowledge base. "
|
||||
"The URL content will be fetched, processed, and stored."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to fetch and ingest",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
"wait": {
|
||||
"type": "boolean",
|
||||
"description": "Wait for ingestion to complete (default: true). Set to false to return immediately with task_id.",
|
||||
"default": True,
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="openrag_get_task_status",
|
||||
description=(
|
||||
"Check the status of an ingestion task. "
|
||||
"Use the task_id returned from openrag_ingest_file when wait=false."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "The task ID to check status for",
|
||||
},
|
||||
},
|
||||
"required": ["task_id"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="openrag_wait_for_task",
|
||||
description=(
|
||||
"Wait for an ingestion task to complete. "
|
||||
"Polls the task status until it completes or fails."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "The task ID to wait for",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "number",
|
||||
"description": "Maximum seconds to wait (default: 300)",
|
||||
"default": 300,
|
||||
},
|
||||
},
|
||||
"required": ["task_id"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="openrag_delete_document",
|
||||
description="Delete a document from the OpenRAG knowledge base.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Name of the file to delete",
|
||||
},
|
||||
},
|
||||
"required": ["filename"],
|
||||
},
|
||||
),
|
||||
]
|
||||
},
|
||||
"required": ["file_path"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def handle_document_tool(name: str, arguments: dict) -> list[TextContent] | None:
|
||||
"""Handle document tool calls. Returns None if tool not handled."""
|
||||
if name == "openrag_ingest_file":
|
||||
return await _ingest_file(arguments)
|
||||
elif name == "openrag_ingest_url":
|
||||
return await _ingest_url(arguments)
|
||||
elif name == "openrag_get_task_status":
|
||||
return await _get_task_status(arguments)
|
||||
elif name == "openrag_wait_for_task":
|
||||
return await _wait_for_task(arguments)
|
||||
elif name == "openrag_delete_document":
|
||||
return await _delete_document(arguments)
|
||||
return None
|
||||
|
||||
|
||||
async def _ingest_file(arguments: dict) -> list[TextContent]:
|
||||
"""Ingest a local file into OpenRAG using the SDK."""
|
||||
async def handle_ingest_file(arguments: dict) -> list[TextContent]:
|
||||
"""Handle openrag_ingest_file tool calls."""
|
||||
file_path = arguments.get("file_path", "")
|
||||
wait = arguments.get("wait", True)
|
||||
|
||||
|
|
@ -154,11 +70,10 @@ async def _ingest_file(arguments: dict) -> list[TextContent]:
|
|||
response = await client.documents.ingest(file_path=path, wait=wait)
|
||||
|
||||
if wait:
|
||||
# Response is IngestTaskStatus when wait=True
|
||||
status = response.status
|
||||
successful = response.successful_files
|
||||
failed = response.failed_files
|
||||
|
||||
|
||||
if status == "completed":
|
||||
result = f"Successfully ingested '{path.name}'."
|
||||
result += f"\nStatus: {status}"
|
||||
|
|
@ -170,7 +85,6 @@ async def _ingest_file(arguments: dict) -> list[TextContent]:
|
|||
result += f"\nSuccessful files: {successful}"
|
||||
result += f"\nFailed files: {failed}"
|
||||
else:
|
||||
# Response is IngestResponse when wait=False
|
||||
result = f"Successfully queued '{response.filename or path.name}' for ingestion."
|
||||
if response.task_id:
|
||||
result += f"\nTask ID: {response.task_id}"
|
||||
|
|
@ -201,11 +115,31 @@ async def _ingest_file(arguments: dict) -> list[TextContent]:
|
|||
return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")]
|
||||
|
||||
|
||||
async def _ingest_url(arguments: dict) -> list[TextContent]:
|
||||
"""Ingest content from a URL into OpenRAG.
|
||||
# ============================================================================
|
||||
# Tool: openrag_ingest_url
|
||||
# ============================================================================
|
||||
|
||||
Note: This uses the SDK's chat to trigger URL ingestion via the agent.
|
||||
"""
|
||||
INGEST_URL_TOOL = Tool(
|
||||
name="openrag_ingest_url",
|
||||
description=(
|
||||
"Ingest content from a URL into the OpenRAG knowledge base. "
|
||||
"The URL content will be fetched, processed, and stored."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The URL to fetch and ingest",
|
||||
},
|
||||
},
|
||||
"required": ["url"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def handle_ingest_url(arguments: dict) -> list[TextContent]:
|
||||
"""Handle openrag_ingest_url tool calls."""
|
||||
url = arguments.get("url", "")
|
||||
|
||||
if not url:
|
||||
|
|
@ -215,7 +149,6 @@ async def _ingest_url(arguments: dict) -> list[TextContent]:
|
|||
return [TextContent(type="text", text="Error: url must start with http:// or https://")]
|
||||
|
||||
try:
|
||||
# Use chat with a special prompt to trigger URL ingestion via the agent
|
||||
client = get_openrag_client()
|
||||
response = await client.chat.create(
|
||||
message=f"Please ingest the content from this URL into the knowledge base: {url}",
|
||||
|
|
@ -237,8 +170,31 @@ async def _ingest_url(arguments: dict) -> list[TextContent]:
|
|||
return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")]
|
||||
|
||||
|
||||
async def _get_task_status(arguments: dict) -> list[TextContent]:
|
||||
"""Get the status of an ingestion task."""
|
||||
# ============================================================================
|
||||
# Tool: openrag_get_task_status
|
||||
# ============================================================================
|
||||
|
||||
GET_TASK_STATUS_TOOL = Tool(
|
||||
name="openrag_get_task_status",
|
||||
description=(
|
||||
"Check the status of an ingestion task. "
|
||||
"Use the task_id returned from openrag_ingest_file when wait=false."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "The task ID to check status for",
|
||||
},
|
||||
},
|
||||
"required": ["task_id"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def handle_get_task_status(arguments: dict) -> list[TextContent]:
|
||||
"""Handle openrag_get_task_status tool calls."""
|
||||
task_id = arguments.get("task_id", "")
|
||||
|
||||
if not task_id:
|
||||
|
|
@ -276,8 +232,36 @@ async def _get_task_status(arguments: dict) -> list[TextContent]:
|
|||
return [TextContent(type="text", text=f"Error getting task status: {str(e)}")]
|
||||
|
||||
|
||||
async def _wait_for_task(arguments: dict) -> list[TextContent]:
|
||||
"""Wait for an ingestion task to complete."""
|
||||
# ============================================================================
|
||||
# Tool: openrag_wait_for_task
|
||||
# ============================================================================
|
||||
|
||||
WAIT_FOR_TASK_TOOL = Tool(
|
||||
name="openrag_wait_for_task",
|
||||
description=(
|
||||
"Wait for an ingestion task to complete. "
|
||||
"Polls the task status until it completes or fails."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "The task ID to wait for",
|
||||
},
|
||||
"timeout": {
|
||||
"type": "number",
|
||||
"description": "Maximum seconds to wait (default: 300)",
|
||||
"default": 300,
|
||||
},
|
||||
},
|
||||
"required": ["task_id"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def handle_wait_for_task(arguments: dict) -> list[TextContent]:
|
||||
"""Handle openrag_wait_for_task tool calls."""
|
||||
task_id = arguments.get("task_id", "")
|
||||
timeout = arguments.get("timeout", 300)
|
||||
|
||||
|
|
@ -318,8 +302,28 @@ async def _wait_for_task(arguments: dict) -> list[TextContent]:
|
|||
return [TextContent(type="text", text=f"Error waiting for task: {str(e)}")]
|
||||
|
||||
|
||||
async def _delete_document(arguments: dict) -> list[TextContent]:
|
||||
"""Delete a document from the knowledge base using the SDK."""
|
||||
# ============================================================================
|
||||
# Tool: openrag_delete_document
|
||||
# ============================================================================
|
||||
|
||||
DELETE_DOCUMENT_TOOL = Tool(
|
||||
name="openrag_delete_document",
|
||||
description="Delete a document from the OpenRAG knowledge base.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Name of the file to delete",
|
||||
},
|
||||
},
|
||||
"required": ["filename"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def handle_delete_document(arguments: dict) -> list[TextContent]:
|
||||
"""Handle openrag_delete_document tool calls."""
|
||||
filename = arguments.get("filename", "")
|
||||
|
||||
if not filename:
|
||||
|
|
@ -355,3 +359,14 @@ async def _delete_document(arguments: dict) -> list[TextContent]:
|
|||
except Exception as e:
|
||||
logger.error(f"Delete document error: {e}")
|
||||
return [TextContent(type="text", text=f"Error deleting document: {str(e)}")]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Register all tools
|
||||
# ============================================================================
|
||||
|
||||
register_tool(INGEST_FILE_TOOL, handle_ingest_file)
|
||||
register_tool(INGEST_URL_TOOL, handle_ingest_url)
|
||||
register_tool(GET_TASK_STATUS_TOOL, handle_get_task_status)
|
||||
register_tool(WAIT_FOR_TASK_TOOL, handle_wait_for_task)
|
||||
register_tool(DELETE_DOCUMENT_TOOL, handle_delete_document)
|
||||
|
|
|
|||
37
sdks/mcp/src/openrag_mcp/tools/registry.py
Normal file
37
sdks/mcp/src/openrag_mcp/tools/registry.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""Tool registry for OpenRAG MCP server."""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
# Type alias for tool handlers
|
||||
ToolHandler = Callable[[dict], Awaitable[list[TextContent]]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolEntry:
|
||||
"""A tool definition with its handler."""
|
||||
tool: Tool
|
||||
handler: ToolHandler
|
||||
|
||||
|
||||
# Global registry: tool_name -> ToolEntry
|
||||
_registry: dict[str, ToolEntry] = {}
|
||||
|
||||
|
||||
def register_tool(tool: Tool, handler: ToolHandler) -> None:
|
||||
"""Register a tool with its handler."""
|
||||
_registry[tool.name] = ToolEntry(tool=tool, handler=handler)
|
||||
|
||||
|
||||
def get_all_tools() -> list[Tool]:
|
||||
"""Get all registered tools."""
|
||||
return [entry.tool for entry in _registry.values()]
|
||||
|
||||
|
||||
def get_handler(name: str) -> ToolHandler | None:
|
||||
"""Get the handler for a tool by name."""
|
||||
entry = _registry.get(name)
|
||||
return entry.handler if entry else None
|
||||
|
||||
|
|
@ -14,63 +14,58 @@ from openrag_sdk import (
|
|||
)
|
||||
|
||||
from openrag_mcp.config import get_openrag_client
|
||||
from openrag_mcp.tools.registry import register_tool
|
||||
|
||||
logger = logging.getLogger("openrag-mcp.search")
|
||||
|
||||
|
||||
def get_search_tools() -> list[Tool]:
|
||||
"""Return search-related tools."""
|
||||
return [
|
||||
Tool(
|
||||
name="openrag_search",
|
||||
description=(
|
||||
"Search the OpenRAG knowledge base using semantic search. "
|
||||
"Returns matching document chunks with relevance scores. "
|
||||
"Optionally filter by data sources or document types."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results (default: 10)",
|
||||
"default": 10,
|
||||
},
|
||||
"score_threshold": {
|
||||
"type": "number",
|
||||
"description": "Minimum relevance score threshold (default: 0)",
|
||||
"default": 0,
|
||||
},
|
||||
"filter_id": {
|
||||
"type": "string",
|
||||
"description": "Optional knowledge filter ID to apply",
|
||||
},
|
||||
"data_sources": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional list of filenames to filter by",
|
||||
},
|
||||
"document_types": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional list of MIME types to filter by (e.g., 'application/pdf')",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
# Tool definition
|
||||
SEARCH_TOOL = Tool(
|
||||
name="openrag_search",
|
||||
description=(
|
||||
"Search the OpenRAG knowledge base using semantic search. "
|
||||
"Returns matching document chunks with relevance scores. "
|
||||
"Optionally filter by data sources or document types."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query",
|
||||
},
|
||||
),
|
||||
]
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results (default: 10)",
|
||||
"default": 10,
|
||||
},
|
||||
"score_threshold": {
|
||||
"type": "number",
|
||||
"description": "Minimum relevance score threshold (default: 0)",
|
||||
"default": 0,
|
||||
},
|
||||
"filter_id": {
|
||||
"type": "string",
|
||||
"description": "Optional knowledge filter ID to apply",
|
||||
},
|
||||
"data_sources": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional list of filenames to filter by",
|
||||
},
|
||||
"document_types": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional list of MIME types to filter by (e.g., 'application/pdf')",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def handle_search_tool(name: str, arguments: dict) -> list[TextContent] | None:
|
||||
"""Handle search tool calls. Returns None if tool not handled."""
|
||||
if name != "openrag_search":
|
||||
return None
|
||||
|
||||
async def handle_search(arguments: dict) -> list[TextContent]:
|
||||
"""Handle openrag_search tool calls."""
|
||||
query = arguments.get("query", "")
|
||||
limit = arguments.get("limit", 10)
|
||||
score_threshold = arguments.get("score_threshold", 0)
|
||||
|
|
@ -138,3 +133,7 @@ async def handle_search_tool(name: str, arguments: dict) -> list[TextContent] |
|
|||
except Exception as e:
|
||||
logger.error(f"Search error: {e}")
|
||||
return [TextContent(type="text", text=f"Error: {str(e)}")]
|
||||
|
||||
|
||||
# Register the tool
|
||||
register_tool(SEARCH_TOOL, handle_search)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue