diff --git a/sdks/mcp/pyproject.toml b/sdks/mcp/pyproject.toml index 3f74a8ad..9f288f2c 100644 --- a/sdks/mcp/pyproject.toml +++ b/sdks/mcp/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openrag-mcp" -version = "0.1.0" +version = "0.1.1" description = "MCP server for OpenRAG - Chat, Search, and Ingest documents" readme = "README.md" requires-python = ">=3.10" diff --git a/sdks/mcp/src/openrag_mcp/server.py b/sdks/mcp/src/openrag_mcp/server.py index 6bfa1182..bfc21776 100644 --- a/sdks/mcp/src/openrag_mcp/server.py +++ b/sdks/mcp/src/openrag_mcp/server.py @@ -1,4 +1,5 @@ """OpenRAG MCP Server - Main server setup and entry point.""" +#TODO: utilize the SDK directly so that any changes in parameters in SDK directky reflects the MCP import asyncio import logging @@ -8,9 +9,9 @@ from mcp.server.stdio import stdio_server from mcp.types import TextContent, Tool from openrag_mcp.config import get_config -from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool -from openrag_mcp.tools.search import get_search_tools, handle_search_tool -from openrag_mcp.tools.documents import get_document_tools, handle_document_tool + +# Import tools module to trigger registration, then get registry functions +from openrag_mcp.tools import get_all_tools, get_handler # Configure logging to stderr (stdout is used for MCP protocol) logging.basicConfig( @@ -30,34 +31,17 @@ def create_server() -> Server: # Create server instance server = Server("openrag-mcp") - # Register a single list_tools handler that combines all tools @server.list_tools() async def list_all_tools() -> list[Tool]: """List all available tools.""" - tools = [] - tools.extend(get_chat_tools()) - tools.extend(get_search_tools()) - tools.extend(get_document_tools()) - return tools + return get_all_tools() - # Register a single call_tool handler that dispatches to appropriate handler @server.call_tool() async def call_tool(name: str, arguments: dict) -> list[TextContent]: """Handle all tool calls by dispatching to the appropriate handler.""" - # Try each handler in order - result = await handle_chat_tool(name, arguments) - if result is not None: - return result - - result = await handle_search_tool(name, arguments) - if result is not None: - return result - - result = await handle_document_tool(name, arguments) - if result is not None: - return result - - # Unknown tool + handler = get_handler(name) + if handler: + return await handler(arguments) return [TextContent(type="text", text=f"Unknown tool: {name}")] logger.info("OpenRAG MCP server initialized with all tools") diff --git a/sdks/mcp/src/openrag_mcp/tools/__init__.py b/sdks/mcp/src/openrag_mcp/tools/__init__.py index da78b042..0225eac5 100644 --- a/sdks/mcp/src/openrag_mcp/tools/__init__.py +++ b/sdks/mcp/src/openrag_mcp/tools/__init__.py @@ -1,14 +1,14 @@ -"""OpenRAG MCP tools.""" +"""OpenRAG MCP tools. -from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool -from openrag_mcp.tools.search import get_search_tools, handle_search_tool -from openrag_mcp.tools.documents import get_document_tools, handle_document_tool +Import this module to register all tools with the registry. +""" -__all__ = [ - "get_chat_tools", - "handle_chat_tool", - "get_search_tools", - "handle_search_tool", - "get_document_tools", - "handle_document_tool", -] +# Import tools to trigger registration +from openrag_mcp.tools import chat # noqa: F401 +from openrag_mcp.tools import search # noqa: F401 +from openrag_mcp.tools import documents # noqa: F401 + +# Re-export registry functions for convenience +from openrag_mcp.tools.registry import get_all_tools, get_handler + +__all__ = ["get_all_tools", "get_handler"] diff --git a/sdks/mcp/src/openrag_mcp/tools/chat.py b/sdks/mcp/src/openrag_mcp/tools/chat.py index fd422e48..13199460 100644 --- a/sdks/mcp/src/openrag_mcp/tools/chat.py +++ b/sdks/mcp/src/openrag_mcp/tools/chat.py @@ -13,58 +13,53 @@ from openrag_sdk import ( ) from openrag_mcp.config import get_openrag_client +from openrag_mcp.tools.registry import register_tool logger = logging.getLogger("openrag-mcp.chat") -def get_chat_tools() -> list[Tool]: - """Return chat-related tools.""" - return [ - Tool( - name="openrag_chat", - description=( - "Send a message to OpenRAG and get a RAG-enhanced response. " - "The response is informed by documents in your knowledge base. " - "Use chat_id to continue a previous conversation, or filter_id " - "to apply a knowledge filter." - ), - inputSchema={ - "type": "object", - "properties": { - "message": { - "type": "string", - "description": "Your question or message to send to OpenRAG", - }, - "chat_id": { - "type": "string", - "description": "Optional conversation ID to continue a previous chat", - }, - "filter_id": { - "type": "string", - "description": "Optional knowledge filter ID to apply", - }, - "limit": { - "type": "integer", - "description": "Maximum number of sources to retrieve (default: 10)", - "default": 10, - }, - "score_threshold": { - "type": "number", - "description": "Minimum relevance score threshold (default: 0)", - "default": 0, - }, - }, - "required": ["message"], +# Tool definition +CHAT_TOOL = Tool( + name="openrag_chat", + description=( + "Send a message to OpenRAG and get a RAG-enhanced response. " + "The response is informed by documents in your knowledge base. " + "Use chat_id to continue a previous conversation, or filter_id " + "to apply a knowledge filter." + ), + inputSchema={ + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Your question or message to send to OpenRAG", }, - ), - ] + "chat_id": { + "type": "string", + "description": "Optional conversation ID to continue a previous chat", + }, + "filter_id": { + "type": "string", + "description": "Optional knowledge filter ID to apply", + }, + "limit": { + "type": "integer", + "description": "Maximum number of sources to retrieve (default: 10)", + "default": 10, + }, + "score_threshold": { + "type": "number", + "description": "Minimum relevance score threshold (default: 0)", + "default": 0, + }, + }, + "required": ["message"], + }, +) -async def handle_chat_tool(name: str, arguments: dict) -> list[TextContent] | None: - """Handle chat tool calls. Returns None if tool not handled.""" - if name != "openrag_chat": - return None - +async def handle_chat(arguments: dict) -> list[TextContent]: + """Handle openrag_chat tool calls.""" message = arguments.get("message", "") chat_id = arguments.get("chat_id") filter_id = arguments.get("filter_id") @@ -115,3 +110,7 @@ async def handle_chat_tool(name: str, arguments: dict) -> list[TextContent] | No except Exception as e: logger.error(f"Chat error: {e}") return [TextContent(type="text", text=f"Error: {str(e)}")] + + +# Register the tool +register_tool(CHAT_TOOL, handle_chat) diff --git a/sdks/mcp/src/openrag_mcp/tools/documents.py b/sdks/mcp/src/openrag_mcp/tools/documents.py index ca679cd4..40e0ee61 100644 --- a/sdks/mcp/src/openrag_mcp/tools/documents.py +++ b/sdks/mcp/src/openrag_mcp/tools/documents.py @@ -15,126 +15,42 @@ from openrag_sdk import ( ) from openrag_mcp.config import get_openrag_client +from openrag_mcp.tools.registry import register_tool logger = logging.getLogger("openrag-mcp.documents") -def get_document_tools() -> list[Tool]: - """Return document-related tools.""" - return [ - Tool( - name="openrag_ingest_file", - description=( - "Ingest a local file into the OpenRAG knowledge base. " - "Supported formats: PDF, DOCX, TXT, MD, HTML, and more. " - "By default waits for ingestion to complete. Set wait=false to return immediately." - ), - inputSchema={ - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Absolute path to the file to ingest", - }, - "wait": { - "type": "boolean", - "description": "Wait for ingestion to complete (default: true). Set to false to return immediately with task_id.", - "default": True, - }, - }, - "required": ["file_path"], +# ============================================================================ +# Tool: openrag_ingest_file +# ============================================================================ + +INGEST_FILE_TOOL = Tool( + name="openrag_ingest_file", + description=( + "Ingest a local file into the OpenRAG knowledge base. " + "Supported formats: PDF, DOCX, TXT, MD, HTML, and more. " + "By default waits for ingestion to complete. Set wait=false to return immediately." + ), + inputSchema={ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Absolute path to the file to ingest", }, - ), - Tool( - name="openrag_ingest_url", - description=( - "Ingest content from a URL into the OpenRAG knowledge base. " - "The URL content will be fetched, processed, and stored." - ), - inputSchema={ - "type": "object", - "properties": { - "url": { - "type": "string", - "description": "The URL to fetch and ingest", - }, - }, - "required": ["url"], + "wait": { + "type": "boolean", + "description": "Wait for ingestion to complete (default: true). Set to false to return immediately with task_id.", + "default": True, }, - ), - Tool( - name="openrag_get_task_status", - description=( - "Check the status of an ingestion task. " - "Use the task_id returned from openrag_ingest_file when wait=false." - ), - inputSchema={ - "type": "object", - "properties": { - "task_id": { - "type": "string", - "description": "The task ID to check status for", - }, - }, - "required": ["task_id"], - }, - ), - Tool( - name="openrag_wait_for_task", - description=( - "Wait for an ingestion task to complete. " - "Polls the task status until it completes or fails." - ), - inputSchema={ - "type": "object", - "properties": { - "task_id": { - "type": "string", - "description": "The task ID to wait for", - }, - "timeout": { - "type": "number", - "description": "Maximum seconds to wait (default: 300)", - "default": 300, - }, - }, - "required": ["task_id"], - }, - ), - Tool( - name="openrag_delete_document", - description="Delete a document from the OpenRAG knowledge base.", - inputSchema={ - "type": "object", - "properties": { - "filename": { - "type": "string", - "description": "Name of the file to delete", - }, - }, - "required": ["filename"], - }, - ), - ] + }, + "required": ["file_path"], + }, +) -async def handle_document_tool(name: str, arguments: dict) -> list[TextContent] | None: - """Handle document tool calls. Returns None if tool not handled.""" - if name == "openrag_ingest_file": - return await _ingest_file(arguments) - elif name == "openrag_ingest_url": - return await _ingest_url(arguments) - elif name == "openrag_get_task_status": - return await _get_task_status(arguments) - elif name == "openrag_wait_for_task": - return await _wait_for_task(arguments) - elif name == "openrag_delete_document": - return await _delete_document(arguments) - return None - - -async def _ingest_file(arguments: dict) -> list[TextContent]: - """Ingest a local file into OpenRAG using the SDK.""" +async def handle_ingest_file(arguments: dict) -> list[TextContent]: + """Handle openrag_ingest_file tool calls.""" file_path = arguments.get("file_path", "") wait = arguments.get("wait", True) @@ -154,11 +70,10 @@ async def _ingest_file(arguments: dict) -> list[TextContent]: response = await client.documents.ingest(file_path=path, wait=wait) if wait: - # Response is IngestTaskStatus when wait=True status = response.status successful = response.successful_files failed = response.failed_files - + if status == "completed": result = f"Successfully ingested '{path.name}'." result += f"\nStatus: {status}" @@ -170,7 +85,6 @@ async def _ingest_file(arguments: dict) -> list[TextContent]: result += f"\nSuccessful files: {successful}" result += f"\nFailed files: {failed}" else: - # Response is IngestResponse when wait=False result = f"Successfully queued '{response.filename or path.name}' for ingestion." if response.task_id: result += f"\nTask ID: {response.task_id}" @@ -201,11 +115,31 @@ async def _ingest_file(arguments: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")] -async def _ingest_url(arguments: dict) -> list[TextContent]: - """Ingest content from a URL into OpenRAG. +# ============================================================================ +# Tool: openrag_ingest_url +# ============================================================================ - Note: This uses the SDK's chat to trigger URL ingestion via the agent. - """ +INGEST_URL_TOOL = Tool( + name="openrag_ingest_url", + description=( + "Ingest content from a URL into the OpenRAG knowledge base. " + "The URL content will be fetched, processed, and stored." + ), + inputSchema={ + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The URL to fetch and ingest", + }, + }, + "required": ["url"], + }, +) + + +async def handle_ingest_url(arguments: dict) -> list[TextContent]: + """Handle openrag_ingest_url tool calls.""" url = arguments.get("url", "") if not url: @@ -215,7 +149,6 @@ async def _ingest_url(arguments: dict) -> list[TextContent]: return [TextContent(type="text", text="Error: url must start with http:// or https://")] try: - # Use chat with a special prompt to trigger URL ingestion via the agent client = get_openrag_client() response = await client.chat.create( message=f"Please ingest the content from this URL into the knowledge base: {url}", @@ -237,8 +170,31 @@ async def _ingest_url(arguments: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")] -async def _get_task_status(arguments: dict) -> list[TextContent]: - """Get the status of an ingestion task.""" +# ============================================================================ +# Tool: openrag_get_task_status +# ============================================================================ + +GET_TASK_STATUS_TOOL = Tool( + name="openrag_get_task_status", + description=( + "Check the status of an ingestion task. " + "Use the task_id returned from openrag_ingest_file when wait=false." + ), + inputSchema={ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "The task ID to check status for", + }, + }, + "required": ["task_id"], + }, +) + + +async def handle_get_task_status(arguments: dict) -> list[TextContent]: + """Handle openrag_get_task_status tool calls.""" task_id = arguments.get("task_id", "") if not task_id: @@ -276,8 +232,36 @@ async def _get_task_status(arguments: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Error getting task status: {str(e)}")] -async def _wait_for_task(arguments: dict) -> list[TextContent]: - """Wait for an ingestion task to complete.""" +# ============================================================================ +# Tool: openrag_wait_for_task +# ============================================================================ + +WAIT_FOR_TASK_TOOL = Tool( + name="openrag_wait_for_task", + description=( + "Wait for an ingestion task to complete. " + "Polls the task status until it completes or fails." + ), + inputSchema={ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "The task ID to wait for", + }, + "timeout": { + "type": "number", + "description": "Maximum seconds to wait (default: 300)", + "default": 300, + }, + }, + "required": ["task_id"], + }, +) + + +async def handle_wait_for_task(arguments: dict) -> list[TextContent]: + """Handle openrag_wait_for_task tool calls.""" task_id = arguments.get("task_id", "") timeout = arguments.get("timeout", 300) @@ -318,8 +302,28 @@ async def _wait_for_task(arguments: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Error waiting for task: {str(e)}")] -async def _delete_document(arguments: dict) -> list[TextContent]: - """Delete a document from the knowledge base using the SDK.""" +# ============================================================================ +# Tool: openrag_delete_document +# ============================================================================ + +DELETE_DOCUMENT_TOOL = Tool( + name="openrag_delete_document", + description="Delete a document from the OpenRAG knowledge base.", + inputSchema={ + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "Name of the file to delete", + }, + }, + "required": ["filename"], + }, +) + + +async def handle_delete_document(arguments: dict) -> list[TextContent]: + """Handle openrag_delete_document tool calls.""" filename = arguments.get("filename", "") if not filename: @@ -355,3 +359,14 @@ async def _delete_document(arguments: dict) -> list[TextContent]: except Exception as e: logger.error(f"Delete document error: {e}") return [TextContent(type="text", text=f"Error deleting document: {str(e)}")] + + +# ============================================================================ +# Register all tools +# ============================================================================ + +register_tool(INGEST_FILE_TOOL, handle_ingest_file) +register_tool(INGEST_URL_TOOL, handle_ingest_url) +register_tool(GET_TASK_STATUS_TOOL, handle_get_task_status) +register_tool(WAIT_FOR_TASK_TOOL, handle_wait_for_task) +register_tool(DELETE_DOCUMENT_TOOL, handle_delete_document) diff --git a/sdks/mcp/src/openrag_mcp/tools/registry.py b/sdks/mcp/src/openrag_mcp/tools/registry.py new file mode 100644 index 00000000..851496ab --- /dev/null +++ b/sdks/mcp/src/openrag_mcp/tools/registry.py @@ -0,0 +1,37 @@ +"""Tool registry for OpenRAG MCP server.""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass + +from mcp.types import TextContent, Tool + +# Type alias for tool handlers +ToolHandler = Callable[[dict], Awaitable[list[TextContent]]] + + +@dataclass +class ToolEntry: + """A tool definition with its handler.""" + tool: Tool + handler: ToolHandler + + +# Global registry: tool_name -> ToolEntry +_registry: dict[str, ToolEntry] = {} + + +def register_tool(tool: Tool, handler: ToolHandler) -> None: + """Register a tool with its handler.""" + _registry[tool.name] = ToolEntry(tool=tool, handler=handler) + + +def get_all_tools() -> list[Tool]: + """Get all registered tools.""" + return [entry.tool for entry in _registry.values()] + + +def get_handler(name: str) -> ToolHandler | None: + """Get the handler for a tool by name.""" + entry = _registry.get(name) + return entry.handler if entry else None + diff --git a/sdks/mcp/src/openrag_mcp/tools/search.py b/sdks/mcp/src/openrag_mcp/tools/search.py index 3185be97..2f87c8ef 100644 --- a/sdks/mcp/src/openrag_mcp/tools/search.py +++ b/sdks/mcp/src/openrag_mcp/tools/search.py @@ -14,63 +14,58 @@ from openrag_sdk import ( ) from openrag_mcp.config import get_openrag_client +from openrag_mcp.tools.registry import register_tool logger = logging.getLogger("openrag-mcp.search") -def get_search_tools() -> list[Tool]: - """Return search-related tools.""" - return [ - Tool( - name="openrag_search", - description=( - "Search the OpenRAG knowledge base using semantic search. " - "Returns matching document chunks with relevance scores. " - "Optionally filter by data sources or document types." - ), - inputSchema={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query", - }, - "limit": { - "type": "integer", - "description": "Maximum number of results (default: 10)", - "default": 10, - }, - "score_threshold": { - "type": "number", - "description": "Minimum relevance score threshold (default: 0)", - "default": 0, - }, - "filter_id": { - "type": "string", - "description": "Optional knowledge filter ID to apply", - }, - "data_sources": { - "type": "array", - "items": {"type": "string"}, - "description": "Optional list of filenames to filter by", - }, - "document_types": { - "type": "array", - "items": {"type": "string"}, - "description": "Optional list of MIME types to filter by (e.g., 'application/pdf')", - }, - }, - "required": ["query"], +# Tool definition +SEARCH_TOOL = Tool( + name="openrag_search", + description=( + "Search the OpenRAG knowledge base using semantic search. " + "Returns matching document chunks with relevance scores. " + "Optionally filter by data sources or document types." + ), + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query", }, - ), - ] + "limit": { + "type": "integer", + "description": "Maximum number of results (default: 10)", + "default": 10, + }, + "score_threshold": { + "type": "number", + "description": "Minimum relevance score threshold (default: 0)", + "default": 0, + }, + "filter_id": { + "type": "string", + "description": "Optional knowledge filter ID to apply", + }, + "data_sources": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional list of filenames to filter by", + }, + "document_types": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional list of MIME types to filter by (e.g., 'application/pdf')", + }, + }, + "required": ["query"], + }, +) -async def handle_search_tool(name: str, arguments: dict) -> list[TextContent] | None: - """Handle search tool calls. Returns None if tool not handled.""" - if name != "openrag_search": - return None - +async def handle_search(arguments: dict) -> list[TextContent]: + """Handle openrag_search tool calls.""" query = arguments.get("query", "") limit = arguments.get("limit", 10) score_threshold = arguments.get("score_threshold", 0) @@ -138,3 +133,7 @@ async def handle_search_tool(name: str, arguments: dict) -> list[TextContent] | except Exception as e: logger.error(f"Search error: {e}") return [TextContent(type="text", text=f"Error: {str(e)}")] + + +# Register the tool +register_tool(SEARCH_TOOL, handle_search)