From a6a473aff0b1f3f726cd549ae46a7a8359c520d9 Mon Sep 17 00:00:00 2001 From: Edwin Jose Date: Fri, 26 Dec 2025 13:58:21 -0500 Subject: [PATCH] update mcp tools --- sdks/mcp/src/openrag_mcp/server.py | 41 +- sdks/mcp/src/openrag_mcp/tools/__init__.py | 16 +- sdks/mcp/src/openrag_mcp/tools/chat.py | 181 ++++--- sdks/mcp/src/openrag_mcp/tools/documents.py | 552 ++++++++++++-------- sdks/mcp/src/openrag_mcp/tools/search.py | 188 ++++--- 5 files changed, 574 insertions(+), 404 deletions(-) diff --git a/sdks/mcp/src/openrag_mcp/server.py b/sdks/mcp/src/openrag_mcp/server.py index ca16ed6a..6bfa1182 100644 --- a/sdks/mcp/src/openrag_mcp/server.py +++ b/sdks/mcp/src/openrag_mcp/server.py @@ -5,11 +5,12 @@ import logging from mcp.server import Server from mcp.server.stdio import stdio_server +from mcp.types import TextContent, Tool from openrag_mcp.config import get_config -from openrag_mcp.tools.chat import register_chat_tools -# from openrag_mcp.tools.search import register_search_tools -# from openrag_mcp.tools.documents import register_document_tools +from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool +from openrag_mcp.tools.search import get_search_tools, handle_search_tool +from openrag_mcp.tools.documents import get_document_tools, handle_document_tool # Configure logging to stderr (stdout is used for MCP protocol) logging.basicConfig( @@ -29,10 +30,35 @@ def create_server() -> Server: # Create server instance server = Server("openrag-mcp") - # Register all tools - register_chat_tools(server) - # register_search_tools(server) - # register_document_tools(server) + # Register a single list_tools handler that combines all tools + @server.list_tools() + async def list_all_tools() -> list[Tool]: + """List all available tools.""" + tools = [] + tools.extend(get_chat_tools()) + tools.extend(get_search_tools()) + tools.extend(get_document_tools()) + return tools + + # Register a single call_tool handler that dispatches to appropriate handler + @server.call_tool() + async def call_tool(name: str, arguments: dict) -> list[TextContent]: + """Handle all tool calls by dispatching to the appropriate handler.""" + # Try each handler in order + result = await handle_chat_tool(name, arguments) + if result is not None: + return result + + result = await handle_search_tool(name, arguments) + if result is not None: + return result + + result = await handle_document_tool(name, arguments) + if result is not None: + return result + + # Unknown tool + return [TextContent(type="text", text=f"Unknown tool: {name}")] logger.info("OpenRAG MCP server initialized with all tools") return server @@ -68,4 +94,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/sdks/mcp/src/openrag_mcp/tools/__init__.py b/sdks/mcp/src/openrag_mcp/tools/__init__.py index 3a2fcc43..da78b042 100644 --- a/sdks/mcp/src/openrag_mcp/tools/__init__.py +++ b/sdks/mcp/src/openrag_mcp/tools/__init__.py @@ -1,8 +1,14 @@ """OpenRAG MCP tools.""" -from openrag_mcp.tools.chat import register_chat_tools -# from openrag_mcp.tools.search import register_search_tools -# from openrag_mcp.tools.documents import register_document_tools - -__all__ = ["register_chat_tools"] +from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool +from openrag_mcp.tools.search import get_search_tools, handle_search_tool +from openrag_mcp.tools.documents import get_document_tools, handle_document_tool +__all__ = [ + "get_chat_tools", + "handle_chat_tool", + "get_search_tools", + "handle_search_tool", + "get_document_tools", + "handle_document_tool", +] diff --git a/sdks/mcp/src/openrag_mcp/tools/chat.py b/sdks/mcp/src/openrag_mcp/tools/chat.py index 0d751036..fd422e48 100644 --- a/sdks/mcp/src/openrag_mcp/tools/chat.py +++ b/sdks/mcp/src/openrag_mcp/tools/chat.py @@ -2,7 +2,6 @@ import logging -from mcp.server import Server from mcp.types import TextContent, Tool from openrag_sdk import ( @@ -18,105 +17,101 @@ from openrag_mcp.config import get_openrag_client logger = logging.getLogger("openrag-mcp.chat") -def register_chat_tools(server: Server) -> None: - """Register chat-related tools with the MCP server.""" - - @server.list_tools() - async def list_chat_tools() -> list[Tool]: - """List chat tools.""" - return [ - Tool( - name="openrag_chat", - description=( - "Send a message to OpenRAG and get a RAG-enhanced response. " - "The response is informed by documents in your knowledge base. " - "Use chat_id to continue a previous conversation, or filter_id " - "to apply a knowledge filter." - ), - inputSchema={ - "type": "object", - "properties": { - "message": { - "type": "string", - "description": "Your question or message to send to OpenRAG", - }, - "chat_id": { - "type": "string", - "description": "Optional conversation ID to continue a previous chat", - }, - "filter_id": { - "type": "string", - "description": "Optional knowledge filter ID to apply", - }, - "limit": { - "type": "integer", - "description": "Maximum number of sources to retrieve (default: 10)", - "default": 10, - }, - "score_threshold": { - "type": "number", - "description": "Minimum relevance score threshold (default: 0)", - "default": 0, - }, - }, - "required": ["message"], - }, +def get_chat_tools() -> list[Tool]: + """Return chat-related tools.""" + return [ + Tool( + name="openrag_chat", + description=( + "Send a message to OpenRAG and get a RAG-enhanced response. " + "The response is informed by documents in your knowledge base. " + "Use chat_id to continue a previous conversation, or filter_id " + "to apply a knowledge filter." ), - ] + inputSchema={ + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "Your question or message to send to OpenRAG", + }, + "chat_id": { + "type": "string", + "description": "Optional conversation ID to continue a previous chat", + }, + "filter_id": { + "type": "string", + "description": "Optional knowledge filter ID to apply", + }, + "limit": { + "type": "integer", + "description": "Maximum number of sources to retrieve (default: 10)", + "default": 10, + }, + "score_threshold": { + "type": "number", + "description": "Minimum relevance score threshold (default: 0)", + "default": 0, + }, + }, + "required": ["message"], + }, + ), + ] - @server.call_tool() - async def call_chat_tool(name: str, arguments: dict) -> list[TextContent]: - """Handle chat tool calls.""" - if name != "openrag_chat": - return [] - message = arguments.get("message", "") - chat_id = arguments.get("chat_id") - filter_id = arguments.get("filter_id") - limit = arguments.get("limit", 10) - score_threshold = arguments.get("score_threshold", 0) +async def handle_chat_tool(name: str, arguments: dict) -> list[TextContent] | None: + """Handle chat tool calls. Returns None if tool not handled.""" + if name != "openrag_chat": + return None - if not message: - return [TextContent(type="text", text="Error: message is required")] + message = arguments.get("message", "") + chat_id = arguments.get("chat_id") + filter_id = arguments.get("filter_id") + limit = arguments.get("limit", 10) + score_threshold = arguments.get("score_threshold", 0) - try: - client = get_openrag_client() - response = await client.chat.create( - message=message, - chat_id=chat_id, - filter_id=filter_id, - limit=limit, - score_threshold=score_threshold, - ) + if not message: + return [TextContent(type="text", text="Error: message is required")] - # Build formatted response - output_parts = [response.response] + try: + client = get_openrag_client() + response = await client.chat.create( + message=message, + chat_id=chat_id, + filter_id=filter_id, + limit=limit, + score_threshold=score_threshold, + ) - if response.sources: - output_parts.append("\n\n---\n**Sources:**") - for i, source in enumerate(response.sources, 1): - output_parts.append(f"\n{i}. {source.filename} (relevance: {source.score:.2f})") + # Build formatted response + output_parts = [response.response] - if response.chat_id: - output_parts.append(f"\n\n_Chat ID: {response.chat_id}_") + if response.sources: + output_parts.append("\n\n---\n**Sources:**") + for i, source in enumerate(response.sources, 1): + output_parts.append(f"\n{i}. {source.filename} (relevance: {source.score:.2f})") - return [TextContent(type="text", text="".join(output_parts))] + if response.chat_id: + output_parts.append(f"\n\n_Chat ID: {response.chat_id}_") - except AuthenticationError as e: - logger.error(f"Authentication error: {e.message}") - return [TextContent(type="text", text=f"Authentication error: {e.message}")] - except ValidationError as e: - logger.error(f"Validation error: {e.message}") - return [TextContent(type="text", text=f"Invalid request: {e.message}")] - except RateLimitError as e: - logger.error(f"Rate limit error: {e.message}") - return [TextContent(type="text", text=f"Rate limited: {e.message}")] - except ServerError as e: - logger.error(f"Server error: {e.message}") - return [TextContent(type="text", text=f"Server error: {e.message}")] - except OpenRAGError as e: - logger.error(f"OpenRAG error: {e.message}") - return [TextContent(type="text", text=f"Error: {e.message}")] - except Exception as e: - logger.error(f"Chat error: {e}") - return [TextContent(type="text", text=f"Error: {str(e)}")] + return [TextContent(type="text", text="".join(output_parts))] + + except AuthenticationError as e: + logger.error(f"Authentication error: {e.message}") + return [TextContent(type="text", text=f"Authentication error: {e.message}")] + except ValidationError as e: + logger.error(f"Validation error: {e.message}") + return [TextContent(type="text", text=f"Invalid request: {e.message}")] + except RateLimitError as e: + logger.error(f"Rate limit error: {e.message}") + return [TextContent(type="text", text=f"Rate limited: {e.message}")] + except ServerError as e: + logger.error(f"Server error: {e.message}") + return [TextContent(type="text", text=f"Server error: {e.message}")] + except OpenRAGError as e: + logger.error(f"OpenRAG error: {e.message}") + return [TextContent(type="text", text=f"Error: {e.message}")] + except Exception as e: + logger.error(f"Chat error: {e}") + return [TextContent(type="text", text=f"Error: {str(e)}")] diff --git a/sdks/mcp/src/openrag_mcp/tools/documents.py b/sdks/mcp/src/openrag_mcp/tools/documents.py index ec3bf101..ca679cd4 100644 --- a/sdks/mcp/src/openrag_mcp/tools/documents.py +++ b/sdks/mcp/src/openrag_mcp/tools/documents.py @@ -1,257 +1,357 @@ -# """Document tools for OpenRAG MCP server.""" +"""Document tools for OpenRAG MCP server.""" -# import logging -# from pathlib import Path +import logging +from pathlib import Path -# from mcp.server import Server -# from mcp.types import TextContent, Tool +from mcp.types import TextContent, Tool -# from openrag_sdk import ( -# AuthenticationError, -# NotFoundError, -# OpenRAGError, -# RateLimitError, -# ServerError, -# ValidationError, -# ) +from openrag_sdk import ( + AuthenticationError, + NotFoundError, + OpenRAGError, + RateLimitError, + ServerError, + ValidationError, +) -# from openrag_mcp.config import get_client, get_openrag_client +from openrag_mcp.config import get_openrag_client -# logger = logging.getLogger("openrag-mcp.documents") +logger = logging.getLogger("openrag-mcp.documents") -# def register_document_tools(server: Server) -> None: -# """Register document-related tools with the MCP server.""" - -# @server.list_tools() -# async def list_document_tools() -> list[Tool]: -# """List document tools.""" -# return [ -# Tool( -# name="openrag_ingest_file", -# description=( -# "Ingest a local file into the OpenRAG knowledge base. " -# "Supported formats: PDF, DOCX, TXT, MD, HTML, and more." -# ), -# inputSchema={ -# "type": "object", -# "properties": { -# "file_path": { -# "type": "string", -# "description": "Absolute path to the file to ingest", -# }, -# }, -# "required": ["file_path"], -# }, -# ), -# Tool( -# name="openrag_ingest_url", -# description=( -# "Ingest content from a URL into the OpenRAG knowledge base. " -# "The URL content will be fetched, processed, and stored." -# ), -# inputSchema={ -# "type": "object", -# "properties": { -# "url": { -# "type": "string", -# "description": "The URL to fetch and ingest", -# }, -# }, -# "required": ["url"], -# }, -# ), -# Tool( -# name="openrag_list_documents", -# description="List documents in the OpenRAG knowledge base.", -# inputSchema={ -# "type": "object", -# "properties": { -# "limit": { -# "type": "integer", -# "description": "Maximum number of documents to return (default: 50)", -# "default": 50, -# }, -# }, -# "required": [], -# }, -# ), -# Tool( -# name="openrag_delete_document", -# description="Delete a document from the OpenRAG knowledge base.", -# inputSchema={ -# "type": "object", -# "properties": { -# "filename": { -# "type": "string", -# "description": "Name of the file to delete", -# }, -# }, -# "required": ["filename"], -# }, -# ), -# ] - -# @server.call_tool() -# async def call_document_tool(name: str, arguments: dict) -> list[TextContent]: -# """Handle document tool calls.""" -# if name == "openrag_ingest_file": -# return await _ingest_file(arguments) -# elif name == "openrag_ingest_url": -# return await _ingest_url(arguments) -# elif name == "openrag_list_documents": -# return await _list_documents(arguments) -# elif name == "openrag_delete_document": -# return await _delete_document(arguments) -# return [] +def get_document_tools() -> list[Tool]: + """Return document-related tools.""" + return [ + Tool( + name="openrag_ingest_file", + description=( + "Ingest a local file into the OpenRAG knowledge base. " + "Supported formats: PDF, DOCX, TXT, MD, HTML, and more. " + "By default waits for ingestion to complete. Set wait=false to return immediately." + ), + inputSchema={ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Absolute path to the file to ingest", + }, + "wait": { + "type": "boolean", + "description": "Wait for ingestion to complete (default: true). Set to false to return immediately with task_id.", + "default": True, + }, + }, + "required": ["file_path"], + }, + ), + Tool( + name="openrag_ingest_url", + description=( + "Ingest content from a URL into the OpenRAG knowledge base. " + "The URL content will be fetched, processed, and stored." + ), + inputSchema={ + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The URL to fetch and ingest", + }, + }, + "required": ["url"], + }, + ), + Tool( + name="openrag_get_task_status", + description=( + "Check the status of an ingestion task. " + "Use the task_id returned from openrag_ingest_file when wait=false." + ), + inputSchema={ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "The task ID to check status for", + }, + }, + "required": ["task_id"], + }, + ), + Tool( + name="openrag_wait_for_task", + description=( + "Wait for an ingestion task to complete. " + "Polls the task status until it completes or fails." + ), + inputSchema={ + "type": "object", + "properties": { + "task_id": { + "type": "string", + "description": "The task ID to wait for", + }, + "timeout": { + "type": "number", + "description": "Maximum seconds to wait (default: 300)", + "default": 300, + }, + }, + "required": ["task_id"], + }, + ), + Tool( + name="openrag_delete_document", + description="Delete a document from the OpenRAG knowledge base.", + inputSchema={ + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "Name of the file to delete", + }, + }, + "required": ["filename"], + }, + ), + ] -# async def _ingest_file(arguments: dict) -> list[TextContent]: -# """Ingest a local file into OpenRAG using the SDK.""" -# file_path = arguments.get("file_path", "") - -# if not file_path: -# return [TextContent(type="text", text="Error: file_path is required")] - -# path = Path(file_path) - -# if not path.exists(): -# return [TextContent(type="text", text=f"Error: File not found: {file_path}")] - -# if not path.is_file(): -# return [TextContent(type="text", text=f"Error: Path is not a file: {file_path}")] - -# try: -# client = get_openrag_client() -# # Use wait=False to return immediately with task_id -# response = await client.documents.ingest(file_path=path, wait=False) - -# result = f"Successfully queued '{response.filename or path.name}' for ingestion." -# if response.task_id: -# result += f"\nTask ID: {response.task_id}" - -# return [TextContent(type="text", text=result)] - -# except AuthenticationError as e: -# logger.error(f"Authentication error: {e.message}") -# return [TextContent(type="text", text=f"Authentication error: {e.message}")] -# except ValidationError as e: -# logger.error(f"Validation error: {e.message}") -# return [TextContent(type="text", text=f"Invalid request: {e.message}")] -# except RateLimitError as e: -# logger.error(f"Rate limit error: {e.message}") -# return [TextContent(type="text", text=f"Rate limited: {e.message}")] -# except ServerError as e: -# logger.error(f"Server error: {e.message}") -# return [TextContent(type="text", text=f"Server error: {e.message}")] -# except OpenRAGError as e: -# logger.error(f"OpenRAG error: {e.message}") -# return [TextContent(type="text", text=f"Error: {e.message}")] -# except Exception as e: -# logger.error(f"Ingest file error: {e}") -# return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")] +async def handle_document_tool(name: str, arguments: dict) -> list[TextContent] | None: + """Handle document tool calls. Returns None if tool not handled.""" + if name == "openrag_ingest_file": + return await _ingest_file(arguments) + elif name == "openrag_ingest_url": + return await _ingest_url(arguments) + elif name == "openrag_get_task_status": + return await _get_task_status(arguments) + elif name == "openrag_wait_for_task": + return await _wait_for_task(arguments) + elif name == "openrag_delete_document": + return await _delete_document(arguments) + return None -# async def _ingest_url(arguments: dict) -> list[TextContent]: -# """Ingest content from a URL into OpenRAG. +async def _ingest_file(arguments: dict) -> list[TextContent]: + """Ingest a local file into OpenRAG using the SDK.""" + file_path = arguments.get("file_path", "") + wait = arguments.get("wait", True) -# Note: This uses the SDK's chat to trigger URL ingestion via the agent. -# """ -# url = arguments.get("url", "") + if not file_path: + return [TextContent(type="text", text="Error: file_path is required")] -# if not url: -# return [TextContent(type="text", text="Error: url is required")] + path = Path(file_path) -# if not url.startswith(("http://", "https://")): -# return [TextContent(type="text", text="Error: url must start with http:// or https://")] + if not path.exists(): + return [TextContent(type="text", text=f"Error: File not found: {file_path}")] -# try: -# # Use chat with a special prompt to trigger URL ingestion via the agent -# client = get_openrag_client() -# response = await client.chat.create( -# message=f"Please ingest the content from this URL into the knowledge base: {url}", -# ) + if not path.is_file(): + return [TextContent(type="text", text=f"Error: Path is not a file: {file_path}")] -# return [TextContent(type="text", text=f"URL ingestion requested.\n\n{response.response}")] + try: + client = get_openrag_client() + response = await client.documents.ingest(file_path=path, wait=wait) -# except AuthenticationError as e: -# logger.error(f"Authentication error: {e.message}") -# return [TextContent(type="text", text=f"Authentication error: {e.message}")] -# except ServerError as e: -# logger.error(f"Server error: {e.message}") -# return [TextContent(type="text", text=f"Server error: {e.message}")] -# except OpenRAGError as e: -# logger.error(f"OpenRAG error: {e.message}") -# return [TextContent(type="text", text=f"Error: {e.message}")] -# except Exception as e: -# logger.error(f"Ingest URL error: {e}") -# return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")] + if wait: + # Response is IngestTaskStatus when wait=True + status = response.status + successful = response.successful_files + failed = response.failed_files + + if status == "completed": + result = f"Successfully ingested '{path.name}'." + result += f"\nStatus: {status}" + result += f"\nSuccessful files: {successful}" + if failed > 0: + result += f"\nFailed files: {failed}" + else: + result = f"Ingestion finished with status: {status}" + result += f"\nSuccessful files: {successful}" + result += f"\nFailed files: {failed}" + else: + # Response is IngestResponse when wait=False + result = f"Successfully queued '{response.filename or path.name}' for ingestion." + if response.task_id: + result += f"\nTask ID: {response.task_id}" + result += "\n\nUse openrag_get_task_status or openrag_wait_for_task to check progress." + + return [TextContent(type="text", text=result)] + + except AuthenticationError as e: + logger.error(f"Authentication error: {e.message}") + return [TextContent(type="text", text=f"Authentication error: {e.message}")] + except ValidationError as e: + logger.error(f"Validation error: {e.message}") + return [TextContent(type="text", text=f"Invalid request: {e.message}")] + except RateLimitError as e: + logger.error(f"Rate limit error: {e.message}") + return [TextContent(type="text", text=f"Rate limited: {e.message}")] + except ServerError as e: + logger.error(f"Server error: {e.message}") + return [TextContent(type="text", text=f"Server error: {e.message}")] + except OpenRAGError as e: + logger.error(f"OpenRAG error: {e.message}") + return [TextContent(type="text", text=f"Error: {e.message}")] + except TimeoutError as e: + logger.error(f"Timeout error: {e}") + return [TextContent(type="text", text=f"Ingestion timed out: {str(e)}")] + except Exception as e: + logger.error(f"Ingest file error: {e}") + return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")] -# async def _list_documents(arguments: dict) -> list[TextContent]: -# """List documents in the knowledge base. +async def _ingest_url(arguments: dict) -> list[TextContent]: + """Ingest content from a URL into OpenRAG. -# Note: This uses direct HTTP calls as the SDK doesn't yet support listing documents. -# """ -# limit = arguments.get("limit", 50) + Note: This uses the SDK's chat to trigger URL ingestion via the agent. + """ + url = arguments.get("url", "") -# try: -# async with get_client() as client: -# response = await client.get("/api/v1/documents", params={"limit": limit}) -# response.raise_for_status() -# data = response.json() + if not url: + return [TextContent(type="text", text="Error: url is required")] -# documents = data.get("documents", []) + if not url.startswith(("http://", "https://")): + return [TextContent(type="text", text="Error: url must start with http:// or https://")] -# if not documents: -# return [TextContent(type="text", text="No documents found in the knowledge base.")] + try: + # Use chat with a special prompt to trigger URL ingestion via the agent + client = get_openrag_client() + response = await client.chat.create( + message=f"Please ingest the content from this URL into the knowledge base: {url}", + ) -# output_parts = [f"Found {len(documents)} document(s):\n"] + return [TextContent(type="text", text=f"URL ingestion requested.\n\n{response.response}")] -# for doc in documents: -# filename = doc.get("filename", "Unknown") -# chunks = doc.get("chunk_count", 0) -# created = doc.get("created_at", "") - -# output_parts.append(f"\n- **{filename}** ({chunks} chunks)") -# if created: -# output_parts.append(f" - Added: {created[:10]}") - -# return [TextContent(type="text", text="".join(output_parts))] - -# except Exception as e: -# logger.error(f"List documents error: {e}") -# return [TextContent(type="text", text=f"Error listing documents: {str(e)}")] + except AuthenticationError as e: + logger.error(f"Authentication error: {e.message}") + return [TextContent(type="text", text=f"Authentication error: {e.message}")] + except ServerError as e: + logger.error(f"Server error: {e.message}") + return [TextContent(type="text", text=f"Server error: {e.message}")] + except OpenRAGError as e: + logger.error(f"OpenRAG error: {e.message}") + return [TextContent(type="text", text=f"Error: {e.message}")] + except Exception as e: + logger.error(f"Ingest URL error: {e}") + return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")] -# async def _delete_document(arguments: dict) -> list[TextContent]: -# """Delete a document from the knowledge base using the SDK.""" -# filename = arguments.get("filename", "") +async def _get_task_status(arguments: dict) -> list[TextContent]: + """Get the status of an ingestion task.""" + task_id = arguments.get("task_id", "") -# if not filename: -# return [TextContent(type="text", text="Error: filename is required")] + if not task_id: + return [TextContent(type="text", text="Error: task_id is required")] -# try: -# client = get_openrag_client() -# response = await client.documents.delete(filename) + try: + client = get_openrag_client() + status = await client.documents.get_task_status(task_id) -# return [TextContent( -# type="text", -# text=f"Successfully deleted '{filename}' ({response.deleted_chunks} chunks removed).", -# )] + output_parts = [f"**Task Status: {status.status}**"] + output_parts.append(f"\nTask ID: {status.task_id}") + output_parts.append(f"\nTotal files: {status.total_files}") + output_parts.append(f"\nProcessed: {status.processed_files}") + output_parts.append(f"\nSuccessful: {status.successful_files}") + output_parts.append(f"\nFailed: {status.failed_files}") -# except NotFoundError as e: -# logger.error(f"Document not found: {e.message}") -# return [TextContent(type="text", text=f"Document not found: {e.message}")] -# except AuthenticationError as e: -# logger.error(f"Authentication error: {e.message}") -# return [TextContent(type="text", text=f"Authentication error: {e.message}")] -# except ServerError as e: -# logger.error(f"Server error: {e.message}") -# return [TextContent(type="text", text=f"Server error: {e.message}")] -# except OpenRAGError as e: -# logger.error(f"OpenRAG error: {e.message}") -# return [TextContent(type="text", text=f"Error: {e.message}")] -# except Exception as e: -# logger.error(f"Delete document error: {e}") -# return [TextContent(type="text", text=f"Error deleting document: {str(e)}")] + if status.files: + output_parts.append("\n\n**File Details:**") + for filename, file_status in status.files.items(): + output_parts.append(f"\n- {filename}: {file_status}") + + return [TextContent(type="text", text="".join(output_parts))] + + except NotFoundError as e: + logger.error(f"Task not found: {e.message}") + return [TextContent(type="text", text=f"Task not found: {e.message}")] + except AuthenticationError as e: + logger.error(f"Authentication error: {e.message}") + return [TextContent(type="text", text=f"Authentication error: {e.message}")] + except OpenRAGError as e: + logger.error(f"OpenRAG error: {e.message}") + return [TextContent(type="text", text=f"Error: {e.message}")] + except Exception as e: + logger.error(f"Get task status error: {e}") + return [TextContent(type="text", text=f"Error getting task status: {str(e)}")] + + +async def _wait_for_task(arguments: dict) -> list[TextContent]: + """Wait for an ingestion task to complete.""" + task_id = arguments.get("task_id", "") + timeout = arguments.get("timeout", 300) + + if not task_id: + return [TextContent(type="text", text="Error: task_id is required")] + + try: + client = get_openrag_client() + status = await client.documents.wait_for_task(task_id, timeout=timeout) + + output_parts = [f"**Task Completed: {status.status}**"] + output_parts.append(f"\nTask ID: {status.task_id}") + output_parts.append(f"\nTotal files: {status.total_files}") + output_parts.append(f"\nSuccessful: {status.successful_files}") + output_parts.append(f"\nFailed: {status.failed_files}") + + if status.files: + output_parts.append("\n\n**File Details:**") + for filename, file_status in status.files.items(): + output_parts.append(f"\n- {filename}: {file_status}") + + return [TextContent(type="text", text="".join(output_parts))] + + except TimeoutError as e: + logger.error(f"Wait for task timeout: {e}") + return [TextContent(type="text", text=f"Task did not complete within {timeout} seconds.")] + except NotFoundError as e: + logger.error(f"Task not found: {e.message}") + return [TextContent(type="text", text=f"Task not found: {e.message}")] + except AuthenticationError as e: + logger.error(f"Authentication error: {e.message}") + return [TextContent(type="text", text=f"Authentication error: {e.message}")] + except OpenRAGError as e: + logger.error(f"OpenRAG error: {e.message}") + return [TextContent(type="text", text=f"Error: {e.message}")] + except Exception as e: + logger.error(f"Wait for task error: {e}") + return [TextContent(type="text", text=f"Error waiting for task: {str(e)}")] + + +async def _delete_document(arguments: dict) -> list[TextContent]: + """Delete a document from the knowledge base using the SDK.""" + filename = arguments.get("filename", "") + + if not filename: + return [TextContent(type="text", text="Error: filename is required")] + + try: + client = get_openrag_client() + response = await client.documents.delete(filename) + + if response.success: + return [TextContent( + type="text", + text=f"Successfully deleted '{filename}' ({response.deleted_chunks} chunks removed).", + )] + else: + return [TextContent( + type="text", + text=f"Failed to delete '{filename}'.", + )] + + except NotFoundError as e: + logger.error(f"Document not found: {e.message}") + return [TextContent(type="text", text=f"Document not found: {e.message}")] + except AuthenticationError as e: + logger.error(f"Authentication error: {e.message}") + return [TextContent(type="text", text=f"Authentication error: {e.message}")] + except ServerError as e: + logger.error(f"Server error: {e.message}") + return [TextContent(type="text", text=f"Server error: {e.message}")] + except OpenRAGError as e: + logger.error(f"OpenRAG error: {e.message}") + return [TextContent(type="text", text=f"Error: {e.message}")] + except Exception as e: + logger.error(f"Delete document error: {e}") + return [TextContent(type="text", text=f"Error deleting document: {str(e)}")] diff --git a/sdks/mcp/src/openrag_mcp/tools/search.py b/sdks/mcp/src/openrag_mcp/tools/search.py index 0a2a05ae..3185be97 100644 --- a/sdks/mcp/src/openrag_mcp/tools/search.py +++ b/sdks/mcp/src/openrag_mcp/tools/search.py @@ -2,95 +2,139 @@ import logging -from mcp.server import Server from mcp.types import TextContent, Tool -from openrag_mcp.config import get_client +from openrag_sdk import ( + AuthenticationError, + OpenRAGError, + RateLimitError, + SearchFilters, + ServerError, + ValidationError, +) + +from openrag_mcp.config import get_openrag_client logger = logging.getLogger("openrag-mcp.search") -def register_search_tools(server: Server) -> None: - """Register search-related tools with the MCP server.""" - - @server.list_tools() - async def list_search_tools() -> list[Tool]: - """List search tools.""" - return [ - Tool( - name="openrag_search", - description=( - "Search the OpenRAG knowledge base using semantic search. " - "Returns matching document chunks with relevance scores." - ), - inputSchema={ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "The search query", - }, - "limit": { - "type": "integer", - "description": "Maximum number of results (default: 10)", - "default": 10, - }, - }, - "required": ["query"], - }, +def get_search_tools() -> list[Tool]: + """Return search-related tools.""" + return [ + Tool( + name="openrag_search", + description=( + "Search the OpenRAG knowledge base using semantic search. " + "Returns matching document chunks with relevance scores. " + "Optionally filter by data sources or document types." ), - ] + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query", + }, + "limit": { + "type": "integer", + "description": "Maximum number of results (default: 10)", + "default": 10, + }, + "score_threshold": { + "type": "number", + "description": "Minimum relevance score threshold (default: 0)", + "default": 0, + }, + "filter_id": { + "type": "string", + "description": "Optional knowledge filter ID to apply", + }, + "data_sources": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional list of filenames to filter by", + }, + "document_types": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional list of MIME types to filter by (e.g., 'application/pdf')", + }, + }, + "required": ["query"], + }, + ), + ] - @server.call_tool() - async def call_search_tool(name: str, arguments: dict) -> list[TextContent]: - """Handle search tool calls.""" - if name != "openrag_search": - return [] - query = arguments.get("query", "") - limit = arguments.get("limit", 10) +async def handle_search_tool(name: str, arguments: dict) -> list[TextContent] | None: + """Handle search tool calls. Returns None if tool not handled.""" + if name != "openrag_search": + return None - if not query: - return [TextContent(type="text", text="Error: query is required")] + query = arguments.get("query", "") + limit = arguments.get("limit", 10) + score_threshold = arguments.get("score_threshold", 0) + filter_id = arguments.get("filter_id") + data_sources = arguments.get("data_sources") + document_types = arguments.get("document_types") - try: - async with get_client() as client: - payload = { - "query": query, - "limit": limit, - } + if not query: + return [TextContent(type="text", text="Error: query is required")] - response = await client.post("/api/v1/search", json=payload) - response.raise_for_status() - data = response.json() + try: + client = get_openrag_client() - results = data.get("results", []) + # Build filters if provided + filters = None + if data_sources or document_types: + filters = SearchFilters( + data_sources=data_sources, + document_types=document_types, + ) - if not results: - return [TextContent(type="text", text="No results found.")] + response = await client.search.query( + query=query, + limit=limit, + score_threshold=score_threshold, + filter_id=filter_id, + filters=filters, + ) - # Format results - output_parts = [f"Found {len(results)} result(s):\n"] + if not response.results: + return [TextContent(type="text", text="No results found.")] - for i, result in enumerate(results, 1): - filename = result.get("filename", "Unknown") - score = result.get("score", 0) - content = result.get("content", "") - page = result.get("page_number") + # Format results + output_parts = [f"Found {len(response.results)} result(s):\n"] - output_parts.append(f"\n---\n**{i}. {filename}**") - if page: - output_parts.append(f" (page {page})") - output_parts.append(f"\nRelevance: {score:.2f}\n") + for i, result in enumerate(response.results, 1): + output_parts.append(f"\n---\n**{i}. {result.filename}**") + if result.page: + output_parts.append(f" (page {result.page})") + output_parts.append(f"\nRelevance: {result.score:.2f}\n") - # Truncate long content - if len(content) > 500: - content = content[:500] + "..." - output_parts.append(f"\n{content}\n") + # Truncate long content + content = result.text + if len(content) > 500: + content = content[:500] + "..." + output_parts.append(f"\n{content}\n") - return [TextContent(type="text", text="".join(output_parts))] - - except Exception as e: - logger.error(f"Search error: {e}") - return [TextContent(type="text", text=f"Error: {str(e)}")] + return [TextContent(type="text", text="".join(output_parts))] + except AuthenticationError as e: + logger.error(f"Authentication error: {e.message}") + return [TextContent(type="text", text=f"Authentication error: {e.message}")] + except ValidationError as e: + logger.error(f"Validation error: {e.message}") + return [TextContent(type="text", text=f"Invalid request: {e.message}")] + except RateLimitError as e: + logger.error(f"Rate limit error: {e.message}") + return [TextContent(type="text", text=f"Rate limited: {e.message}")] + except ServerError as e: + logger.error(f"Server error: {e.message}") + return [TextContent(type="text", text=f"Server error: {e.message}")] + except OpenRAGError as e: + logger.error(f"OpenRAG error: {e.message}") + return [TextContent(type="text", text=f"Error: {e.message}")] + except Exception as e: + logger.error(f"Search error: {e}") + return [TextContent(type="text", text=f"Error: {str(e)}")]