update to add tool registry

This commit is contained in:
Edwin Jose 2025-12-26 14:21:03 -05:00
parent a6a473aff0
commit 91f90d9d9d
7 changed files with 292 additions and 258 deletions

View file

@ -1,6 +1,6 @@
[project]
name = "openrag-mcp"
version = "0.1.0"
version = "0.1.1"
description = "MCP server for OpenRAG - Chat, Search, and Ingest documents"
readme = "README.md"
requires-python = ">=3.10"

View file

@ -1,4 +1,5 @@
"""OpenRAG MCP Server - Main server setup and entry point."""
# TODO: use the SDK directly so that any parameter changes in the SDK are directly reflected in the MCP server
import asyncio
import logging
@ -8,9 +9,9 @@ from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool
from openrag_mcp.config import get_config
from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool
from openrag_mcp.tools.search import get_search_tools, handle_search_tool
from openrag_mcp.tools.documents import get_document_tools, handle_document_tool
# Import tools module to trigger registration, then get registry functions
from openrag_mcp.tools import get_all_tools, get_handler
# Configure logging to stderr (stdout is used for MCP protocol)
logging.basicConfig(
@ -30,34 +31,17 @@ def create_server() -> Server:
# Create server instance
server = Server("openrag-mcp")
# Register a single list_tools handler that combines all tools
@server.list_tools()
async def list_all_tools() -> list[Tool]:
"""List all available tools."""
tools = []
tools.extend(get_chat_tools())
tools.extend(get_search_tools())
tools.extend(get_document_tools())
return tools
return get_all_tools()
# Register a single call_tool handler that dispatches to appropriate handler
@server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
"""Handle all tool calls by dispatching to the appropriate handler."""
# Try each handler in order
result = await handle_chat_tool(name, arguments)
if result is not None:
return result
result = await handle_search_tool(name, arguments)
if result is not None:
return result
result = await handle_document_tool(name, arguments)
if result is not None:
return result
# Unknown tool
handler = get_handler(name)
if handler:
return await handler(arguments)
return [TextContent(type="text", text=f"Unknown tool: {name}")]
logger.info("OpenRAG MCP server initialized with all tools")

View file

@ -1,14 +1,14 @@
"""OpenRAG MCP tools."""
"""OpenRAG MCP tools.
from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool
from openrag_mcp.tools.search import get_search_tools, handle_search_tool
from openrag_mcp.tools.documents import get_document_tools, handle_document_tool
Import this module to register all tools with the registry.
"""
__all__ = [
"get_chat_tools",
"handle_chat_tool",
"get_search_tools",
"handle_search_tool",
"get_document_tools",
"handle_document_tool",
]
# Import tools to trigger registration
from openrag_mcp.tools import chat # noqa: F401
from openrag_mcp.tools import search # noqa: F401
from openrag_mcp.tools import documents # noqa: F401
# Re-export registry functions for convenience
from openrag_mcp.tools.registry import get_all_tools, get_handler
__all__ = ["get_all_tools", "get_handler"]

View file

@ -13,14 +13,13 @@ from openrag_sdk import (
)
from openrag_mcp.config import get_openrag_client
from openrag_mcp.tools.registry import register_tool
logger = logging.getLogger("openrag-mcp.chat")
def get_chat_tools() -> list[Tool]:
"""Return chat-related tools."""
return [
Tool(
# Tool definition
CHAT_TOOL = Tool(
name="openrag_chat",
description=(
"Send a message to OpenRAG and get a RAG-enhanced response. "
@ -56,15 +55,11 @@ def get_chat_tools() -> list[Tool]:
},
"required": ["message"],
},
),
]
)
async def handle_chat_tool(name: str, arguments: dict) -> list[TextContent] | None:
"""Handle chat tool calls. Returns None if tool not handled."""
if name != "openrag_chat":
return None
async def handle_chat(arguments: dict) -> list[TextContent]:
"""Handle openrag_chat tool calls."""
message = arguments.get("message", "")
chat_id = arguments.get("chat_id")
filter_id = arguments.get("filter_id")
@ -115,3 +110,7 @@ async def handle_chat_tool(name: str, arguments: dict) -> list[TextContent] | No
except Exception as e:
logger.error(f"Chat error: {e}")
return [TextContent(type="text", text=f"Error: {str(e)}")]
# Register the tool. This runs at import time, so importing this module is
# what adds openrag_chat to the shared tool registry.
register_tool(CHAT_TOOL, handle_chat)

View file

@ -15,14 +15,16 @@ from openrag_sdk import (
)
from openrag_mcp.config import get_openrag_client
from openrag_mcp.tools.registry import register_tool
logger = logging.getLogger("openrag-mcp.documents")
def get_document_tools() -> list[Tool]:
"""Return document-related tools."""
return [
Tool(
# ============================================================================
# Tool: openrag_ingest_file
# ============================================================================
INGEST_FILE_TOOL = Tool(
name="openrag_ingest_file",
description=(
"Ingest a local file into the OpenRAG knowledge base. "
@ -44,97 +46,11 @@ def get_document_tools() -> list[Tool]:
},
"required": ["file_path"],
},
),
Tool(
name="openrag_ingest_url",
description=(
"Ingest content from a URL into the OpenRAG knowledge base. "
"The URL content will be fetched, processed, and stored."
),
inputSchema={
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to fetch and ingest",
},
},
"required": ["url"],
},
),
Tool(
name="openrag_get_task_status",
description=(
"Check the status of an ingestion task. "
"Use the task_id returned from openrag_ingest_file when wait=false."
),
inputSchema={
"type": "object",
"properties": {
"task_id": {
"type": "string",
"description": "The task ID to check status for",
},
},
"required": ["task_id"],
},
),
Tool(
name="openrag_wait_for_task",
description=(
"Wait for an ingestion task to complete. "
"Polls the task status until it completes or fails."
),
inputSchema={
"type": "object",
"properties": {
"task_id": {
"type": "string",
"description": "The task ID to wait for",
},
"timeout": {
"type": "number",
"description": "Maximum seconds to wait (default: 300)",
"default": 300,
},
},
"required": ["task_id"],
},
),
Tool(
name="openrag_delete_document",
description="Delete a document from the OpenRAG knowledge base.",
inputSchema={
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name of the file to delete",
},
},
"required": ["filename"],
},
),
]
)
async def handle_document_tool(name: str, arguments: dict) -> list[TextContent] | None:
"""Handle document tool calls. Returns None if tool not handled."""
if name == "openrag_ingest_file":
return await _ingest_file(arguments)
elif name == "openrag_ingest_url":
return await _ingest_url(arguments)
elif name == "openrag_get_task_status":
return await _get_task_status(arguments)
elif name == "openrag_wait_for_task":
return await _wait_for_task(arguments)
elif name == "openrag_delete_document":
return await _delete_document(arguments)
return None
async def _ingest_file(arguments: dict) -> list[TextContent]:
"""Ingest a local file into OpenRAG using the SDK."""
async def handle_ingest_file(arguments: dict) -> list[TextContent]:
"""Handle openrag_ingest_file tool calls."""
file_path = arguments.get("file_path", "")
wait = arguments.get("wait", True)
@ -154,7 +70,6 @@ async def _ingest_file(arguments: dict) -> list[TextContent]:
response = await client.documents.ingest(file_path=path, wait=wait)
if wait:
# Response is IngestTaskStatus when wait=True
status = response.status
successful = response.successful_files
failed = response.failed_files
@ -170,7 +85,6 @@ async def _ingest_file(arguments: dict) -> list[TextContent]:
result += f"\nSuccessful files: {successful}"
result += f"\nFailed files: {failed}"
else:
# Response is IngestResponse when wait=False
result = f"Successfully queued '{response.filename or path.name}' for ingestion."
if response.task_id:
result += f"\nTask ID: {response.task_id}"
@ -201,11 +115,31 @@ async def _ingest_file(arguments: dict) -> list[TextContent]:
return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")]
async def _ingest_url(arguments: dict) -> list[TextContent]:
"""Ingest content from a URL into OpenRAG.
# ============================================================================
# Tool: openrag_ingest_url
# ============================================================================
Note: This uses the SDK's chat to trigger URL ingestion via the agent.
"""
# Tool definition for openrag_ingest_url. The input schema requires a single
# string property, "url".
INGEST_URL_TOOL = Tool(
name="openrag_ingest_url",
description=(
"Ingest content from a URL into the OpenRAG knowledge base. "
"The URL content will be fetched, processed, and stored."
),
inputSchema={
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to fetch and ingest",
},
},
"required": ["url"],
},
)
async def handle_ingest_url(arguments: dict) -> list[TextContent]:
"""Handle openrag_ingest_url tool calls."""
url = arguments.get("url", "")
if not url:
@ -215,7 +149,6 @@ async def _ingest_url(arguments: dict) -> list[TextContent]:
return [TextContent(type="text", text="Error: url must start with http:// or https://")]
try:
# Use chat with a special prompt to trigger URL ingestion via the agent
client = get_openrag_client()
response = await client.chat.create(
message=f"Please ingest the content from this URL into the knowledge base: {url}",
@ -237,8 +170,31 @@ async def _ingest_url(arguments: dict) -> list[TextContent]:
return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")]
async def _get_task_status(arguments: dict) -> list[TextContent]:
"""Get the status of an ingestion task."""
# ============================================================================
# Tool: openrag_get_task_status
# ============================================================================
# Tool definition for openrag_get_task_status. The input schema requires a
# single string property, "task_id".
GET_TASK_STATUS_TOOL = Tool(
name="openrag_get_task_status",
description=(
"Check the status of an ingestion task. "
"Use the task_id returned from openrag_ingest_file when wait=false."
),
inputSchema={
"type": "object",
"properties": {
"task_id": {
"type": "string",
"description": "The task ID to check status for",
},
},
"required": ["task_id"],
},
)
async def handle_get_task_status(arguments: dict) -> list[TextContent]:
"""Handle openrag_get_task_status tool calls."""
task_id = arguments.get("task_id", "")
if not task_id:
@ -276,8 +232,36 @@ async def _get_task_status(arguments: dict) -> list[TextContent]:
return [TextContent(type="text", text=f"Error getting task status: {str(e)}")]
async def _wait_for_task(arguments: dict) -> list[TextContent]:
"""Wait for an ingestion task to complete."""
# ============================================================================
# Tool: openrag_wait_for_task
# ============================================================================
# Tool definition for openrag_wait_for_task. The input schema requires a
# string "task_id" and accepts an optional numeric "timeout" (seconds) whose
# declared schema default is 300.
WAIT_FOR_TASK_TOOL = Tool(
name="openrag_wait_for_task",
description=(
"Wait for an ingestion task to complete. "
"Polls the task status until it completes or fails."
),
inputSchema={
"type": "object",
"properties": {
"task_id": {
"type": "string",
"description": "The task ID to wait for",
},
"timeout": {
"type": "number",
"description": "Maximum seconds to wait (default: 300)",
"default": 300,
},
},
"required": ["task_id"],
},
)
async def handle_wait_for_task(arguments: dict) -> list[TextContent]:
"""Handle openrag_wait_for_task tool calls."""
task_id = arguments.get("task_id", "")
timeout = arguments.get("timeout", 300)
@ -318,8 +302,28 @@ async def _wait_for_task(arguments: dict) -> list[TextContent]:
return [TextContent(type="text", text=f"Error waiting for task: {str(e)}")]
async def _delete_document(arguments: dict) -> list[TextContent]:
"""Delete a document from the knowledge base using the SDK."""
# ============================================================================
# Tool: openrag_delete_document
# ============================================================================
# Tool definition for openrag_delete_document. The input schema requires a
# single string property, "filename".
DELETE_DOCUMENT_TOOL = Tool(
name="openrag_delete_document",
description="Delete a document from the OpenRAG knowledge base.",
inputSchema={
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name of the file to delete",
},
},
"required": ["filename"],
},
)
async def handle_delete_document(arguments: dict) -> list[TextContent]:
"""Handle openrag_delete_document tool calls."""
filename = arguments.get("filename", "")
if not filename:
@ -355,3 +359,14 @@ async def _delete_document(arguments: dict) -> list[TextContent]:
except Exception as e:
logger.error(f"Delete document error: {e}")
return [TextContent(type="text", text=f"Error deleting document: {str(e)}")]
# ============================================================================
# Register all tools
#
# These calls run at import time: importing this module is what adds the
# document tools (definition + handler pairs) to the shared tool registry.
# ============================================================================
register_tool(INGEST_FILE_TOOL, handle_ingest_file)
register_tool(INGEST_URL_TOOL, handle_ingest_url)
register_tool(GET_TASK_STATUS_TOOL, handle_get_task_status)
register_tool(WAIT_FOR_TASK_TOOL, handle_wait_for_task)
register_tool(DELETE_DOCUMENT_TOOL, handle_delete_document)

View file

@ -0,0 +1,37 @@
"""Tool registry for OpenRAG MCP server."""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from mcp.types import TextContent, Tool
# Type alias for tool handlers: a callable that takes the tool-call arguments
# as a dict and returns an awaitable resolving to a list of TextContent.
ToolHandler = Callable[[dict], Awaitable[list[TextContent]]]
@dataclass
class ToolEntry:
    """Registry record pairing a tool definition with the handler that serves it."""

    # The Tool definition returned when registered tools are listed.
    tool: Tool
    # Callable awaited with the call's arguments dict when the tool is invoked.
    handler: ToolHandler
# Global registry: tool_name -> ToolEntry.
# Written by register_tool(); read by get_all_tools() and get_handler().
_registry: dict[str, ToolEntry] = {}
def register_tool(tool: Tool, handler: ToolHandler) -> None:
    """Store ``tool`` and ``handler`` in the global registry under ``tool.name``.

    Registering a second tool under an existing name replaces the earlier
    entry.
    """
    record = ToolEntry(tool=tool, handler=handler)
    _registry[tool.name] = record
def get_all_tools() -> list[Tool]:
    """Return the Tool definition of every registry entry, in the registry's iteration order."""
    definitions = []
    for record in _registry.values():
        definitions.append(record.tool)
    return definitions
def get_handler(name: str) -> ToolHandler | None:
    """Look up the handler registered for the tool called ``name``.

    Returns None when no tool has been registered under that name.
    """
    record = _registry.get(name)
    if record is None:
        return None
    return record.handler

View file

@ -14,14 +14,13 @@ from openrag_sdk import (
)
from openrag_mcp.config import get_openrag_client
from openrag_mcp.tools.registry import register_tool
logger = logging.getLogger("openrag-mcp.search")
def get_search_tools() -> list[Tool]:
"""Return search-related tools."""
return [
Tool(
# Tool definition
SEARCH_TOOL = Tool(
name="openrag_search",
description=(
"Search the OpenRAG knowledge base using semantic search. "
@ -62,15 +61,11 @@ def get_search_tools() -> list[Tool]:
},
"required": ["query"],
},
),
]
)
async def handle_search_tool(name: str, arguments: dict) -> list[TextContent] | None:
"""Handle search tool calls. Returns None if tool not handled."""
if name != "openrag_search":
return None
async def handle_search(arguments: dict) -> list[TextContent]:
"""Handle openrag_search tool calls."""
query = arguments.get("query", "")
limit = arguments.get("limit", 10)
score_threshold = arguments.get("score_threshold", 0)
@ -138,3 +133,7 @@ async def handle_search_tool(name: str, arguments: dict) -> list[TextContent] |
except Exception as e:
logger.error(f"Search error: {e}")
return [TextContent(type="text", text=f"Error: {str(e)}")]
# Register the tool. This runs at import time, so importing this module is
# what adds openrag_search to the shared tool registry.
register_tool(SEARCH_TOOL, handle_search)