update mcp tools

This commit is contained in:
Edwin Jose 2025-12-26 13:58:21 -05:00
parent cf8750d906
commit a6a473aff0
5 changed files with 574 additions and 404 deletions

View file

@ -5,11 +5,12 @@ import logging
from mcp.server import Server from mcp.server import Server
from mcp.server.stdio import stdio_server from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool
from openrag_mcp.config import get_config from openrag_mcp.config import get_config
from openrag_mcp.tools.chat import register_chat_tools from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool
# from openrag_mcp.tools.search import register_search_tools from openrag_mcp.tools.search import get_search_tools, handle_search_tool
# from openrag_mcp.tools.documents import register_document_tools from openrag_mcp.tools.documents import get_document_tools, handle_document_tool
# Configure logging to stderr (stdout is used for MCP protocol) # Configure logging to stderr (stdout is used for MCP protocol)
logging.basicConfig( logging.basicConfig(
@ -29,10 +30,35 @@ def create_server() -> Server:
# Create server instance # Create server instance
server = Server("openrag-mcp") server = Server("openrag-mcp")
# Register all tools # Register a single list_tools handler that combines all tools
register_chat_tools(server) @server.list_tools()
# register_search_tools(server) async def list_all_tools() -> list[Tool]:
# register_document_tools(server) """List all available tools."""
tools = []
tools.extend(get_chat_tools())
tools.extend(get_search_tools())
tools.extend(get_document_tools())
return tools
# Register a single call_tool handler that dispatches to appropriate handler
@server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    """Route a tool invocation to the first handler that recognises it.

    The chat, search and document handlers are consulted in that order.
    A handler signals "not mine" by returning ``None``; the first
    non-``None`` result is returned unchanged. When no handler claims the
    name, a single text item reporting the unknown tool is returned.
    """
    # Order matters only in that it mirrors the registration order used
    # for listing tools: chat, then search, then documents.
    dispatchers = (handle_chat_tool, handle_search_tool, handle_document_tool)

    for dispatch in dispatchers:
        outcome = await dispatch(name, arguments)
        if outcome is not None:
            return outcome

    # Nobody claimed the tool name.
    return [TextContent(type="text", text=f"Unknown tool: {name}")]
logger.info("OpenRAG MCP server initialized with all tools") logger.info("OpenRAG MCP server initialized with all tools")
return server return server
@ -68,4 +94,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View file

@ -1,8 +1,14 @@
"""OpenRAG MCP tools.""" """OpenRAG MCP tools."""
from openrag_mcp.tools.chat import register_chat_tools from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool
# from openrag_mcp.tools.search import register_search_tools from openrag_mcp.tools.search import get_search_tools, handle_search_tool
# from openrag_mcp.tools.documents import register_document_tools from openrag_mcp.tools.documents import get_document_tools, handle_document_tool
__all__ = ["register_chat_tools"]
__all__ = [
"get_chat_tools",
"handle_chat_tool",
"get_search_tools",
"handle_search_tool",
"get_document_tools",
"handle_document_tool",
]

View file

@ -2,7 +2,6 @@
import logging import logging
from mcp.server import Server
from mcp.types import TextContent, Tool from mcp.types import TextContent, Tool
from openrag_sdk import ( from openrag_sdk import (
@ -18,105 +17,101 @@ from openrag_mcp.config import get_openrag_client
logger = logging.getLogger("openrag-mcp.chat") logger = logging.getLogger("openrag-mcp.chat")
def register_chat_tools(server: Server) -> None: def get_chat_tools() -> list[Tool]:
"""Register chat-related tools with the MCP server.""" """Return chat-related tools."""
return [
@server.list_tools() Tool(
async def list_chat_tools() -> list[Tool]: name="openrag_chat",
"""List chat tools.""" description=(
return [ "Send a message to OpenRAG and get a RAG-enhanced response. "
Tool( "The response is informed by documents in your knowledge base. "
name="openrag_chat", "Use chat_id to continue a previous conversation, or filter_id "
description=( "to apply a knowledge filter."
"Send a message to OpenRAG and get a RAG-enhanced response. "
"The response is informed by documents in your knowledge base. "
"Use chat_id to continue a previous conversation, or filter_id "
"to apply a knowledge filter."
),
inputSchema={
"type": "object",
"properties": {
"message": {
"type": "string",
"description": "Your question or message to send to OpenRAG",
},
"chat_id": {
"type": "string",
"description": "Optional conversation ID to continue a previous chat",
},
"filter_id": {
"type": "string",
"description": "Optional knowledge filter ID to apply",
},
"limit": {
"type": "integer",
"description": "Maximum number of sources to retrieve (default: 10)",
"default": 10,
},
"score_threshold": {
"type": "number",
"description": "Minimum relevance score threshold (default: 0)",
"default": 0,
},
},
"required": ["message"],
},
), ),
] inputSchema={
"type": "object",
"properties": {
"message": {
"type": "string",
"description": "Your question or message to send to OpenRAG",
},
"chat_id": {
"type": "string",
"description": "Optional conversation ID to continue a previous chat",
},
"filter_id": {
"type": "string",
"description": "Optional knowledge filter ID to apply",
},
"limit": {
"type": "integer",
"description": "Maximum number of sources to retrieve (default: 10)",
"default": 10,
},
"score_threshold": {
"type": "number",
"description": "Minimum relevance score threshold (default: 0)",
"default": 0,
},
},
"required": ["message"],
},
),
]
@server.call_tool()
async def call_chat_tool(name: str, arguments: dict) -> list[TextContent]:
"""Handle chat tool calls."""
if name != "openrag_chat":
return []
message = arguments.get("message", "") async def handle_chat_tool(name: str, arguments: dict) -> list[TextContent] | None:
chat_id = arguments.get("chat_id") """Handle chat tool calls. Returns None if tool not handled."""
filter_id = arguments.get("filter_id") if name != "openrag_chat":
limit = arguments.get("limit", 10) return None
score_threshold = arguments.get("score_threshold", 0)
if not message: message = arguments.get("message", "")
return [TextContent(type="text", text="Error: message is required")] chat_id = arguments.get("chat_id")
filter_id = arguments.get("filter_id")
limit = arguments.get("limit", 10)
score_threshold = arguments.get("score_threshold", 0)
try: if not message:
client = get_openrag_client() return [TextContent(type="text", text="Error: message is required")]
response = await client.chat.create(
message=message,
chat_id=chat_id,
filter_id=filter_id,
limit=limit,
score_threshold=score_threshold,
)
# Build formatted response try:
output_parts = [response.response] client = get_openrag_client()
response = await client.chat.create(
message=message,
chat_id=chat_id,
filter_id=filter_id,
limit=limit,
score_threshold=score_threshold,
)
if response.sources: # Build formatted response
output_parts.append("\n\n---\n**Sources:**") output_parts = [response.response]
for i, source in enumerate(response.sources, 1):
output_parts.append(f"\n{i}. {source.filename} (relevance: {source.score:.2f})")
if response.chat_id: if response.sources:
output_parts.append(f"\n\n_Chat ID: {response.chat_id}_") output_parts.append("\n\n---\n**Sources:**")
for i, source in enumerate(response.sources, 1):
output_parts.append(f"\n{i}. {source.filename} (relevance: {source.score:.2f})")
return [TextContent(type="text", text="".join(output_parts))] if response.chat_id:
output_parts.append(f"\n\n_Chat ID: {response.chat_id}_")
except AuthenticationError as e: return [TextContent(type="text", text="".join(output_parts))]
logger.error(f"Authentication error: {e.message}")
return [TextContent(type="text", text=f"Authentication error: {e.message}")] except AuthenticationError as e:
except ValidationError as e: logger.error(f"Authentication error: {e.message}")
logger.error(f"Validation error: {e.message}") return [TextContent(type="text", text=f"Authentication error: {e.message}")]
return [TextContent(type="text", text=f"Invalid request: {e.message}")] except ValidationError as e:
except RateLimitError as e: logger.error(f"Validation error: {e.message}")
logger.error(f"Rate limit error: {e.message}") return [TextContent(type="text", text=f"Invalid request: {e.message}")]
return [TextContent(type="text", text=f"Rate limited: {e.message}")] except RateLimitError as e:
except ServerError as e: logger.error(f"Rate limit error: {e.message}")
logger.error(f"Server error: {e.message}") return [TextContent(type="text", text=f"Rate limited: {e.message}")]
return [TextContent(type="text", text=f"Server error: {e.message}")] except ServerError as e:
except OpenRAGError as e: logger.error(f"Server error: {e.message}")
logger.error(f"OpenRAG error: {e.message}") return [TextContent(type="text", text=f"Server error: {e.message}")]
return [TextContent(type="text", text=f"Error: {e.message}")] except OpenRAGError as e:
except Exception as e: logger.error(f"OpenRAG error: {e.message}")
logger.error(f"Chat error: {e}") return [TextContent(type="text", text=f"Error: {e.message}")]
return [TextContent(type="text", text=f"Error: {str(e)}")] except Exception as e:
logger.error(f"Chat error: {e}")
return [TextContent(type="text", text=f"Error: {str(e)}")]

View file

@ -1,257 +1,357 @@
# """Document tools for OpenRAG MCP server.""" """Document tools for OpenRAG MCP server."""
# import logging import logging
# from pathlib import Path from pathlib import Path
# from mcp.server import Server from mcp.types import TextContent, Tool
# from mcp.types import TextContent, Tool
# from openrag_sdk import ( from openrag_sdk import (
# AuthenticationError, AuthenticationError,
# NotFoundError, NotFoundError,
# OpenRAGError, OpenRAGError,
# RateLimitError, RateLimitError,
# ServerError, ServerError,
# ValidationError, ValidationError,
# ) )
# from openrag_mcp.config import get_client, get_openrag_client from openrag_mcp.config import get_openrag_client
# logger = logging.getLogger("openrag-mcp.documents") logger = logging.getLogger("openrag-mcp.documents")
# def register_document_tools(server: Server) -> None: def get_document_tools() -> list[Tool]:
# """Register document-related tools with the MCP server.""" """Return document-related tools."""
return [
# @server.list_tools() Tool(
# async def list_document_tools() -> list[Tool]: name="openrag_ingest_file",
# """List document tools.""" description=(
# return [ "Ingest a local file into the OpenRAG knowledge base. "
# Tool( "Supported formats: PDF, DOCX, TXT, MD, HTML, and more. "
# name="openrag_ingest_file", "By default waits for ingestion to complete. Set wait=false to return immediately."
# description=( ),
# "Ingest a local file into the OpenRAG knowledge base. " inputSchema={
# "Supported formats: PDF, DOCX, TXT, MD, HTML, and more." "type": "object",
# ), "properties": {
# inputSchema={ "file_path": {
# "type": "object", "type": "string",
# "properties": { "description": "Absolute path to the file to ingest",
# "file_path": { },
# "type": "string", "wait": {
# "description": "Absolute path to the file to ingest", "type": "boolean",
# }, "description": "Wait for ingestion to complete (default: true). Set to false to return immediately with task_id.",
# }, "default": True,
# "required": ["file_path"], },
# }, },
# ), "required": ["file_path"],
# Tool( },
# name="openrag_ingest_url", ),
# description=( Tool(
# "Ingest content from a URL into the OpenRAG knowledge base. " name="openrag_ingest_url",
# "The URL content will be fetched, processed, and stored." description=(
# ), "Ingest content from a URL into the OpenRAG knowledge base. "
# inputSchema={ "The URL content will be fetched, processed, and stored."
# "type": "object", ),
# "properties": { inputSchema={
# "url": { "type": "object",
# "type": "string", "properties": {
# "description": "The URL to fetch and ingest", "url": {
# }, "type": "string",
# }, "description": "The URL to fetch and ingest",
# "required": ["url"], },
# }, },
# ), "required": ["url"],
# Tool( },
# name="openrag_list_documents", ),
# description="List documents in the OpenRAG knowledge base.", Tool(
# inputSchema={ name="openrag_get_task_status",
# "type": "object", description=(
# "properties": { "Check the status of an ingestion task. "
# "limit": { "Use the task_id returned from openrag_ingest_file when wait=false."
# "type": "integer", ),
# "description": "Maximum number of documents to return (default: 50)", inputSchema={
# "default": 50, "type": "object",
# }, "properties": {
# }, "task_id": {
# "required": [], "type": "string",
# }, "description": "The task ID to check status for",
# ), },
# Tool( },
# name="openrag_delete_document", "required": ["task_id"],
# description="Delete a document from the OpenRAG knowledge base.", },
# inputSchema={ ),
# "type": "object", Tool(
# "properties": { name="openrag_wait_for_task",
# "filename": { description=(
# "type": "string", "Wait for an ingestion task to complete. "
# "description": "Name of the file to delete", "Polls the task status until it completes or fails."
# }, ),
# }, inputSchema={
# "required": ["filename"], "type": "object",
# }, "properties": {
# ), "task_id": {
# ] "type": "string",
"description": "The task ID to wait for",
# @server.call_tool() },
# async def call_document_tool(name: str, arguments: dict) -> list[TextContent]: "timeout": {
# """Handle document tool calls.""" "type": "number",
# if name == "openrag_ingest_file": "description": "Maximum seconds to wait (default: 300)",
# return await _ingest_file(arguments) "default": 300,
# elif name == "openrag_ingest_url": },
# return await _ingest_url(arguments) },
# elif name == "openrag_list_documents": "required": ["task_id"],
# return await _list_documents(arguments) },
# elif name == "openrag_delete_document": ),
# return await _delete_document(arguments) Tool(
# return [] name="openrag_delete_document",
description="Delete a document from the OpenRAG knowledge base.",
inputSchema={
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name of the file to delete",
},
},
"required": ["filename"],
},
),
]
# async def _ingest_file(arguments: dict) -> list[TextContent]: async def handle_document_tool(name: str, arguments: dict) -> list[TextContent] | None:
# """Ingest a local file into OpenRAG using the SDK.""" """Handle document tool calls. Returns None if tool not handled."""
# file_path = arguments.get("file_path", "") if name == "openrag_ingest_file":
return await _ingest_file(arguments)
# if not file_path: elif name == "openrag_ingest_url":
# return [TextContent(type="text", text="Error: file_path is required")] return await _ingest_url(arguments)
elif name == "openrag_get_task_status":
# path = Path(file_path) return await _get_task_status(arguments)
elif name == "openrag_wait_for_task":
# if not path.exists(): return await _wait_for_task(arguments)
# return [TextContent(type="text", text=f"Error: File not found: {file_path}")] elif name == "openrag_delete_document":
return await _delete_document(arguments)
# if not path.is_file(): return None
# return [TextContent(type="text", text=f"Error: Path is not a file: {file_path}")]
# try:
# client = get_openrag_client()
# # Use wait=False to return immediately with task_id
# response = await client.documents.ingest(file_path=path, wait=False)
# result = f"Successfully queued '{response.filename or path.name}' for ingestion."
# if response.task_id:
# result += f"\nTask ID: {response.task_id}"
# return [TextContent(type="text", text=result)]
# except AuthenticationError as e:
# logger.error(f"Authentication error: {e.message}")
# return [TextContent(type="text", text=f"Authentication error: {e.message}")]
# except ValidationError as e:
# logger.error(f"Validation error: {e.message}")
# return [TextContent(type="text", text=f"Invalid request: {e.message}")]
# except RateLimitError as e:
# logger.error(f"Rate limit error: {e.message}")
# return [TextContent(type="text", text=f"Rate limited: {e.message}")]
# except ServerError as e:
# logger.error(f"Server error: {e.message}")
# return [TextContent(type="text", text=f"Server error: {e.message}")]
# except OpenRAGError as e:
# logger.error(f"OpenRAG error: {e.message}")
# return [TextContent(type="text", text=f"Error: {e.message}")]
# except Exception as e:
# logger.error(f"Ingest file error: {e}")
# return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")]
# async def _ingest_url(arguments: dict) -> list[TextContent]: async def _ingest_file(arguments: dict) -> list[TextContent]:
# """Ingest content from a URL into OpenRAG. """Ingest a local file into OpenRAG using the SDK."""
file_path = arguments.get("file_path", "")
wait = arguments.get("wait", True)
# Note: This uses the SDK's chat to trigger URL ingestion via the agent. if not file_path:
# """ return [TextContent(type="text", text="Error: file_path is required")]
# url = arguments.get("url", "")
# if not url: path = Path(file_path)
# return [TextContent(type="text", text="Error: url is required")]
# if not url.startswith(("http://", "https://")): if not path.exists():
# return [TextContent(type="text", text="Error: url must start with http:// or https://")] return [TextContent(type="text", text=f"Error: File not found: {file_path}")]
# try: if not path.is_file():
# # Use chat with a special prompt to trigger URL ingestion via the agent return [TextContent(type="text", text=f"Error: Path is not a file: {file_path}")]
# client = get_openrag_client()
# response = await client.chat.create(
# message=f"Please ingest the content from this URL into the knowledge base: {url}",
# )
# return [TextContent(type="text", text=f"URL ingestion requested.\n\n{response.response}")] try:
client = get_openrag_client()
response = await client.documents.ingest(file_path=path, wait=wait)
# except AuthenticationError as e: if wait:
# logger.error(f"Authentication error: {e.message}") # Response is IngestTaskStatus when wait=True
# return [TextContent(type="text", text=f"Authentication error: {e.message}")] status = response.status
# except ServerError as e: successful = response.successful_files
# logger.error(f"Server error: {e.message}") failed = response.failed_files
# return [TextContent(type="text", text=f"Server error: {e.message}")]
# except OpenRAGError as e: if status == "completed":
# logger.error(f"OpenRAG error: {e.message}") result = f"Successfully ingested '{path.name}'."
# return [TextContent(type="text", text=f"Error: {e.message}")] result += f"\nStatus: {status}"
# except Exception as e: result += f"\nSuccessful files: {successful}"
# logger.error(f"Ingest URL error: {e}") if failed > 0:
# return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")] result += f"\nFailed files: {failed}"
else:
result = f"Ingestion finished with status: {status}"
result += f"\nSuccessful files: {successful}"
result += f"\nFailed files: {failed}"
else:
# Response is IngestResponse when wait=False
result = f"Successfully queued '{response.filename or path.name}' for ingestion."
if response.task_id:
result += f"\nTask ID: {response.task_id}"
result += "\n\nUse openrag_get_task_status or openrag_wait_for_task to check progress."
return [TextContent(type="text", text=result)]
except AuthenticationError as e:
logger.error(f"Authentication error: {e.message}")
return [TextContent(type="text", text=f"Authentication error: {e.message}")]
except ValidationError as e:
logger.error(f"Validation error: {e.message}")
return [TextContent(type="text", text=f"Invalid request: {e.message}")]
except RateLimitError as e:
logger.error(f"Rate limit error: {e.message}")
return [TextContent(type="text", text=f"Rate limited: {e.message}")]
except ServerError as e:
logger.error(f"Server error: {e.message}")
return [TextContent(type="text", text=f"Server error: {e.message}")]
except OpenRAGError as e:
logger.error(f"OpenRAG error: {e.message}")
return [TextContent(type="text", text=f"Error: {e.message}")]
except TimeoutError as e:
logger.error(f"Timeout error: {e}")
return [TextContent(type="text", text=f"Ingestion timed out: {str(e)}")]
except Exception as e:
logger.error(f"Ingest file error: {e}")
return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")]
# async def _list_documents(arguments: dict) -> list[TextContent]: async def _ingest_url(arguments: dict) -> list[TextContent]:
# """List documents in the knowledge base. """Ingest content from a URL into OpenRAG.
# Note: This uses direct HTTP calls as the SDK doesn't yet support listing documents. Note: This uses the SDK's chat to trigger URL ingestion via the agent.
# """ """
# limit = arguments.get("limit", 50) url = arguments.get("url", "")
# try: if not url:
# async with get_client() as client: return [TextContent(type="text", text="Error: url is required")]
# response = await client.get("/api/v1/documents", params={"limit": limit})
# response.raise_for_status()
# data = response.json()
# documents = data.get("documents", []) if not url.startswith(("http://", "https://")):
return [TextContent(type="text", text="Error: url must start with http:// or https://")]
# if not documents: try:
# return [TextContent(type="text", text="No documents found in the knowledge base.")] # Use chat with a special prompt to trigger URL ingestion via the agent
client = get_openrag_client()
response = await client.chat.create(
message=f"Please ingest the content from this URL into the knowledge base: {url}",
)
# output_parts = [f"Found {len(documents)} document(s):\n"] return [TextContent(type="text", text=f"URL ingestion requested.\n\n{response.response}")]
# for doc in documents: except AuthenticationError as e:
# filename = doc.get("filename", "Unknown") logger.error(f"Authentication error: {e.message}")
# chunks = doc.get("chunk_count", 0) return [TextContent(type="text", text=f"Authentication error: {e.message}")]
# created = doc.get("created_at", "") except ServerError as e:
logger.error(f"Server error: {e.message}")
# output_parts.append(f"\n- **{filename}** ({chunks} chunks)") return [TextContent(type="text", text=f"Server error: {e.message}")]
# if created: except OpenRAGError as e:
# output_parts.append(f" - Added: {created[:10]}") logger.error(f"OpenRAG error: {e.message}")
return [TextContent(type="text", text=f"Error: {e.message}")]
# return [TextContent(type="text", text="".join(output_parts))] except Exception as e:
logger.error(f"Ingest URL error: {e}")
# except Exception as e: return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")]
# logger.error(f"List documents error: {e}")
# return [TextContent(type="text", text=f"Error listing documents: {str(e)}")]
# async def _delete_document(arguments: dict) -> list[TextContent]: async def _get_task_status(arguments: dict) -> list[TextContent]:
# """Delete a document from the knowledge base using the SDK.""" """Get the status of an ingestion task."""
# filename = arguments.get("filename", "") task_id = arguments.get("task_id", "")
# if not filename: if not task_id:
# return [TextContent(type="text", text="Error: filename is required")] return [TextContent(type="text", text="Error: task_id is required")]
# try: try:
# client = get_openrag_client() client = get_openrag_client()
# response = await client.documents.delete(filename) status = await client.documents.get_task_status(task_id)
# return [TextContent( output_parts = [f"**Task Status: {status.status}**"]
# type="text", output_parts.append(f"\nTask ID: {status.task_id}")
# text=f"Successfully deleted '{filename}' ({response.deleted_chunks} chunks removed).", output_parts.append(f"\nTotal files: {status.total_files}")
# )] output_parts.append(f"\nProcessed: {status.processed_files}")
output_parts.append(f"\nSuccessful: {status.successful_files}")
output_parts.append(f"\nFailed: {status.failed_files}")
# except NotFoundError as e: if status.files:
# logger.error(f"Document not found: {e.message}") output_parts.append("\n\n**File Details:**")
# return [TextContent(type="text", text=f"Document not found: {e.message}")] for filename, file_status in status.files.items():
# except AuthenticationError as e: output_parts.append(f"\n- {filename}: {file_status}")
# logger.error(f"Authentication error: {e.message}")
# return [TextContent(type="text", text=f"Authentication error: {e.message}")] return [TextContent(type="text", text="".join(output_parts))]
# except ServerError as e:
# logger.error(f"Server error: {e.message}") except NotFoundError as e:
# return [TextContent(type="text", text=f"Server error: {e.message}")] logger.error(f"Task not found: {e.message}")
# except OpenRAGError as e: return [TextContent(type="text", text=f"Task not found: {e.message}")]
# logger.error(f"OpenRAG error: {e.message}") except AuthenticationError as e:
# return [TextContent(type="text", text=f"Error: {e.message}")] logger.error(f"Authentication error: {e.message}")
# except Exception as e: return [TextContent(type="text", text=f"Authentication error: {e.message}")]
# logger.error(f"Delete document error: {e}") except OpenRAGError as e:
# return [TextContent(type="text", text=f"Error deleting document: {str(e)}")] logger.error(f"OpenRAG error: {e.message}")
return [TextContent(type="text", text=f"Error: {e.message}")]
except Exception as e:
logger.error(f"Get task status error: {e}")
return [TextContent(type="text", text=f"Error getting task status: {str(e)}")]
async def _wait_for_task(arguments: dict) -> list[TextContent]:
    """Block until an ingestion task finishes and report its final status.

    Reads ``task_id`` (required) and ``timeout`` (defaults to 300) from
    ``arguments`` and delegates the waiting to the SDK client's
    ``documents.wait_for_task``. The resulting status is rendered as one
    text item. Every failure is logged and converted into a text message
    instead of being raised to the caller.
    """
    task_id = arguments.get("task_id", "")
    timeout = arguments.get("timeout", 300)

    # Guard clause: there is nothing to wait on without a task ID.
    if not task_id:
        return [TextContent(type="text", text="Error: task_id is required")]

    try:
        client = get_openrag_client()
        status = await client.documents.wait_for_task(task_id, timeout=timeout)

        # Summary header followed by the per-task counters.
        lines = [
            f"**Task Completed: {status.status}**",
            f"\nTask ID: {status.task_id}",
            f"\nTotal files: {status.total_files}",
            f"\nSuccessful: {status.successful_files}",
            f"\nFailed: {status.failed_files}",
        ]

        # Optional per-file breakdown, one bullet per file.
        if status.files:
            lines.append("\n\n**File Details:**")
            lines.extend(
                f"\n- {filename}: {file_status}"
                for filename, file_status in status.files.items()
            )

        return [TextContent(type="text", text="".join(lines))]

    except TimeoutError as exc:
        logger.error(f"Wait for task timeout: {exc}")
        return [TextContent(type="text", text=f"Task did not complete within {timeout} seconds.")]
    except NotFoundError as exc:
        logger.error(f"Task not found: {exc.message}")
        return [TextContent(type="text", text=f"Task not found: {exc.message}")]
    except AuthenticationError as exc:
        logger.error(f"Authentication error: {exc.message}")
        return [TextContent(type="text", text=f"Authentication error: {exc.message}")]
    except OpenRAGError as exc:
        logger.error(f"OpenRAG error: {exc.message}")
        return [TextContent(type="text", text=f"Error: {exc.message}")]
    except Exception as exc:
        # Last-resort boundary: report anything unexpected as text.
        logger.error(f"Wait for task error: {exc}")
        return [TextContent(type="text", text=f"Error waiting for task: {str(exc)}")]
async def _delete_document(arguments: dict) -> list[TextContent]:
    """Remove a document from the knowledge base via the SDK client.

    Reads ``filename`` (required) from ``arguments`` and calls the SDK
    client's ``documents.delete``. The reply is one text item: a success
    message including the number of deleted chunks, or a failure message
    when the response's ``success`` flag is falsy. Errors are logged and
    returned as text rather than raised.
    """
    filename = arguments.get("filename", "")

    # Guard clause: a filename is mandatory.
    if not filename:
        return [TextContent(type="text", text="Error: filename is required")]

    try:
        client = get_openrag_client()
        response = await client.documents.delete(filename)

        # Handle the unsuccessful outcome first so the happy path reads flat.
        if not response.success:
            return [TextContent(type="text", text=f"Failed to delete '{filename}'.")]

        return [
            TextContent(
                type="text",
                text=f"Successfully deleted '{filename}' ({response.deleted_chunks} chunks removed).",
            )
        ]

    except NotFoundError as exc:
        logger.error(f"Document not found: {exc.message}")
        return [TextContent(type="text", text=f"Document not found: {exc.message}")]
    except AuthenticationError as exc:
        logger.error(f"Authentication error: {exc.message}")
        return [TextContent(type="text", text=f"Authentication error: {exc.message}")]
    except ServerError as exc:
        logger.error(f"Server error: {exc.message}")
        return [TextContent(type="text", text=f"Server error: {exc.message}")]
    except OpenRAGError as exc:
        logger.error(f"OpenRAG error: {exc.message}")
        return [TextContent(type="text", text=f"Error: {exc.message}")]
    except Exception as exc:
        # Last-resort boundary: report anything unexpected as text.
        logger.error(f"Delete document error: {exc}")
        return [TextContent(type="text", text=f"Error deleting document: {str(exc)}")]

View file

@ -2,95 +2,139 @@
import logging import logging
from mcp.server import Server
from mcp.types import TextContent, Tool from mcp.types import TextContent, Tool
from openrag_mcp.config import get_client from openrag_sdk import (
AuthenticationError,
OpenRAGError,
RateLimitError,
SearchFilters,
ServerError,
ValidationError,
)
from openrag_mcp.config import get_openrag_client
logger = logging.getLogger("openrag-mcp.search") logger = logging.getLogger("openrag-mcp.search")
def register_search_tools(server: Server) -> None: def get_search_tools() -> list[Tool]:
"""Register search-related tools with the MCP server.""" """Return search-related tools."""
return [
@server.list_tools() Tool(
async def list_search_tools() -> list[Tool]: name="openrag_search",
"""List search tools.""" description=(
return [ "Search the OpenRAG knowledge base using semantic search. "
Tool( "Returns matching document chunks with relevance scores. "
name="openrag_search", "Optionally filter by data sources or document types."
description=(
"Search the OpenRAG knowledge base using semantic search. "
"Returns matching document chunks with relevance scores."
),
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query",
},
"limit": {
"type": "integer",
"description": "Maximum number of results (default: 10)",
"default": 10,
},
},
"required": ["query"],
},
), ),
] inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query",
},
"limit": {
"type": "integer",
"description": "Maximum number of results (default: 10)",
"default": 10,
},
"score_threshold": {
"type": "number",
"description": "Minimum relevance score threshold (default: 0)",
"default": 0,
},
"filter_id": {
"type": "string",
"description": "Optional knowledge filter ID to apply",
},
"data_sources": {
"type": "array",
"items": {"type": "string"},
"description": "Optional list of filenames to filter by",
},
"document_types": {
"type": "array",
"items": {"type": "string"},
"description": "Optional list of MIME types to filter by (e.g., 'application/pdf')",
},
},
"required": ["query"],
},
),
]
@server.call_tool()
async def call_search_tool(name: str, arguments: dict) -> list[TextContent]:
"""Handle search tool calls."""
if name != "openrag_search":
return []
query = arguments.get("query", "") async def handle_search_tool(name: str, arguments: dict) -> list[TextContent] | None:
limit = arguments.get("limit", 10) """Handle search tool calls. Returns None if tool not handled."""
if name != "openrag_search":
return None
if not query: query = arguments.get("query", "")
return [TextContent(type="text", text="Error: query is required")] limit = arguments.get("limit", 10)
score_threshold = arguments.get("score_threshold", 0)
filter_id = arguments.get("filter_id")
data_sources = arguments.get("data_sources")
document_types = arguments.get("document_types")
try: if not query:
async with get_client() as client: return [TextContent(type="text", text="Error: query is required")]
payload = {
"query": query,
"limit": limit,
}
response = await client.post("/api/v1/search", json=payload) try:
response.raise_for_status() client = get_openrag_client()
data = response.json()
results = data.get("results", []) # Build filters if provided
filters = None
if data_sources or document_types:
filters = SearchFilters(
data_sources=data_sources,
document_types=document_types,
)
if not results: response = await client.search.query(
return [TextContent(type="text", text="No results found.")] query=query,
limit=limit,
score_threshold=score_threshold,
filter_id=filter_id,
filters=filters,
)
# Format results if not response.results:
output_parts = [f"Found {len(results)} result(s):\n"] return [TextContent(type="text", text="No results found.")]
for i, result in enumerate(results, 1): # Format results
filename = result.get("filename", "Unknown") output_parts = [f"Found {len(response.results)} result(s):\n"]
score = result.get("score", 0)
content = result.get("content", "")
page = result.get("page_number")
output_parts.append(f"\n---\n**{i}. {filename}**") for i, result in enumerate(response.results, 1):
if page: output_parts.append(f"\n---\n**{i}. {result.filename}**")
output_parts.append(f" (page {page})") if result.page:
output_parts.append(f"\nRelevance: {score:.2f}\n") output_parts.append(f" (page {result.page})")
output_parts.append(f"\nRelevance: {result.score:.2f}\n")
# Truncate long content # Truncate long content
if len(content) > 500: content = result.text
content = content[:500] + "..." if len(content) > 500:
output_parts.append(f"\n{content}\n") content = content[:500] + "..."
output_parts.append(f"\n{content}\n")
return [TextContent(type="text", text="".join(output_parts))] return [TextContent(type="text", text="".join(output_parts))]
except Exception as e:
logger.error(f"Search error: {e}")
return [TextContent(type="text", text=f"Error: {str(e)}")]
except AuthenticationError as e:
logger.error(f"Authentication error: {e.message}")
return [TextContent(type="text", text=f"Authentication error: {e.message}")]
except ValidationError as e:
logger.error(f"Validation error: {e.message}")
return [TextContent(type="text", text=f"Invalid request: {e.message}")]
except RateLimitError as e:
logger.error(f"Rate limit error: {e.message}")
return [TextContent(type="text", text=f"Rate limited: {e.message}")]
except ServerError as e:
logger.error(f"Server error: {e.message}")
return [TextContent(type="text", text=f"Server error: {e.message}")]
except OpenRAGError as e:
logger.error(f"OpenRAG error: {e.message}")
return [TextContent(type="text", text=f"Error: {e.message}")]
except Exception as e:
logger.error(f"Search error: {e}")
return [TextContent(type="text", text=f"Error: {str(e)}")]