update mcp tools
This commit is contained in:
parent
cf8750d906
commit
a6a473aff0
5 changed files with 574 additions and 404 deletions
|
|
@ -5,11 +5,12 @@ import logging
|
|||
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
from openrag_mcp.config import get_config
|
||||
from openrag_mcp.tools.chat import register_chat_tools
|
||||
# from openrag_mcp.tools.search import register_search_tools
|
||||
# from openrag_mcp.tools.documents import register_document_tools
|
||||
from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool
|
||||
from openrag_mcp.tools.search import get_search_tools, handle_search_tool
|
||||
from openrag_mcp.tools.documents import get_document_tools, handle_document_tool
|
||||
|
||||
# Configure logging to stderr (stdout is used for MCP protocol)
|
||||
logging.basicConfig(
|
||||
|
|
@ -29,10 +30,35 @@ def create_server() -> Server:
|
|||
# Create server instance
|
||||
server = Server("openrag-mcp")
|
||||
|
||||
# Register all tools
|
||||
register_chat_tools(server)
|
||||
# register_search_tools(server)
|
||||
# register_document_tools(server)
|
||||
# Register a single list_tools handler that combines all tools
|
||||
@server.list_tools()
async def list_all_tools() -> list[Tool]:
    """Return every tool definition: chat tools, then search tools, then document tools."""
    # Unpacking keeps the same call order as building the list step by step.
    return [*get_chat_tools(), *get_search_tools(), *get_document_tools()]
|
||||
|
||||
# Register a single call_tool handler that dispatches to appropriate handler
|
||||
@server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
    """Route a tool call to the chat, search and document handlers, in that order.

    The first handler result that is not ``None`` is returned as-is. When all
    three handlers return ``None``, a single text item naming the unknown tool
    is returned instead.
    """
    for handler in (handle_chat_tool, handle_search_tool, handle_document_tool):
        outcome = await handler(name, arguments)
        if outcome is not None:
            return outcome

    # No handler produced a result for this name.
    return [TextContent(type="text", text=f"Unknown tool: {name}")]
|
||||
|
||||
logger.info("OpenRAG MCP server initialized with all tools")
|
||||
return server
|
||||
|
|
@ -68,4 +94,3 @@ def main():
|
|||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
"""OpenRAG MCP tools."""
|
||||
|
||||
from openrag_mcp.tools.chat import register_chat_tools
|
||||
# from openrag_mcp.tools.search import register_search_tools
|
||||
# from openrag_mcp.tools.documents import register_document_tools
|
||||
|
||||
__all__ = ["register_chat_tools"]
|
||||
from openrag_mcp.tools.chat import get_chat_tools, handle_chat_tool
|
||||
from openrag_mcp.tools.search import get_search_tools, handle_search_tool
|
||||
from openrag_mcp.tools.documents import get_document_tools, handle_document_tool
|
||||
|
||||
__all__ = [
|
||||
"get_chat_tools",
|
||||
"handle_chat_tool",
|
||||
"get_search_tools",
|
||||
"handle_search_tool",
|
||||
"get_document_tools",
|
||||
"handle_document_tool",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import logging
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
from openrag_sdk import (
|
||||
|
|
@ -18,105 +17,101 @@ from openrag_mcp.config import get_openrag_client
|
|||
logger = logging.getLogger("openrag-mcp.chat")
|
||||
|
||||
|
||||
def register_chat_tools(server: Server) -> None:
|
||||
"""Register chat-related tools with the MCP server."""
|
||||
|
||||
@server.list_tools()
|
||||
async def list_chat_tools() -> list[Tool]:
|
||||
"""List chat tools."""
|
||||
return [
|
||||
Tool(
|
||||
name="openrag_chat",
|
||||
description=(
|
||||
"Send a message to OpenRAG and get a RAG-enhanced response. "
|
||||
"The response is informed by documents in your knowledge base. "
|
||||
"Use chat_id to continue a previous conversation, or filter_id "
|
||||
"to apply a knowledge filter."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Your question or message to send to OpenRAG",
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Optional conversation ID to continue a previous chat",
|
||||
},
|
||||
"filter_id": {
|
||||
"type": "string",
|
||||
"description": "Optional knowledge filter ID to apply",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of sources to retrieve (default: 10)",
|
||||
"default": 10,
|
||||
},
|
||||
"score_threshold": {
|
||||
"type": "number",
|
||||
"description": "Minimum relevance score threshold (default: 0)",
|
||||
"default": 0,
|
||||
},
|
||||
},
|
||||
"required": ["message"],
|
||||
},
|
||||
def get_chat_tools() -> list[Tool]:
|
||||
"""Return chat-related tools."""
|
||||
return [
|
||||
Tool(
|
||||
name="openrag_chat",
|
||||
description=(
|
||||
"Send a message to OpenRAG and get a RAG-enhanced response. "
|
||||
"The response is informed by documents in your knowledge base. "
|
||||
"Use chat_id to continue a previous conversation, or filter_id "
|
||||
"to apply a knowledge filter."
|
||||
),
|
||||
]
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Your question or message to send to OpenRAG",
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Optional conversation ID to continue a previous chat",
|
||||
},
|
||||
"filter_id": {
|
||||
"type": "string",
|
||||
"description": "Optional knowledge filter ID to apply",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of sources to retrieve (default: 10)",
|
||||
"default": 10,
|
||||
},
|
||||
"score_threshold": {
|
||||
"type": "number",
|
||||
"description": "Minimum relevance score threshold (default: 0)",
|
||||
"default": 0,
|
||||
},
|
||||
},
|
||||
"required": ["message"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@server.call_tool()
|
||||
async def call_chat_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
"""Handle chat tool calls."""
|
||||
if name != "openrag_chat":
|
||||
return []
|
||||
|
||||
message = arguments.get("message", "")
|
||||
chat_id = arguments.get("chat_id")
|
||||
filter_id = arguments.get("filter_id")
|
||||
limit = arguments.get("limit", 10)
|
||||
score_threshold = arguments.get("score_threshold", 0)
|
||||
async def handle_chat_tool(name: str, arguments: dict) -> list[TextContent] | None:
|
||||
"""Handle chat tool calls. Returns None if tool not handled."""
|
||||
if name != "openrag_chat":
|
||||
return None
|
||||
|
||||
if not message:
|
||||
return [TextContent(type="text", text="Error: message is required")]
|
||||
message = arguments.get("message", "")
|
||||
chat_id = arguments.get("chat_id")
|
||||
filter_id = arguments.get("filter_id")
|
||||
limit = arguments.get("limit", 10)
|
||||
score_threshold = arguments.get("score_threshold", 0)
|
||||
|
||||
try:
|
||||
client = get_openrag_client()
|
||||
response = await client.chat.create(
|
||||
message=message,
|
||||
chat_id=chat_id,
|
||||
filter_id=filter_id,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
if not message:
|
||||
return [TextContent(type="text", text="Error: message is required")]
|
||||
|
||||
# Build formatted response
|
||||
output_parts = [response.response]
|
||||
try:
|
||||
client = get_openrag_client()
|
||||
response = await client.chat.create(
|
||||
message=message,
|
||||
chat_id=chat_id,
|
||||
filter_id=filter_id,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
if response.sources:
|
||||
output_parts.append("\n\n---\n**Sources:**")
|
||||
for i, source in enumerate(response.sources, 1):
|
||||
output_parts.append(f"\n{i}. {source.filename} (relevance: {source.score:.2f})")
|
||||
# Build formatted response
|
||||
output_parts = [response.response]
|
||||
|
||||
if response.chat_id:
|
||||
output_parts.append(f"\n\n_Chat ID: {response.chat_id}_")
|
||||
if response.sources:
|
||||
output_parts.append("\n\n---\n**Sources:**")
|
||||
for i, source in enumerate(response.sources, 1):
|
||||
output_parts.append(f"\n{i}. {source.filename} (relevance: {source.score:.2f})")
|
||||
|
||||
return [TextContent(type="text", text="".join(output_parts))]
|
||||
if response.chat_id:
|
||||
output_parts.append(f"\n\n_Chat ID: {response.chat_id}_")
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.error(f"Authentication error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Authentication error: {e.message}")]
|
||||
except ValidationError as e:
|
||||
logger.error(f"Validation error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Invalid request: {e.message}")]
|
||||
except RateLimitError as e:
|
||||
logger.error(f"Rate limit error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Rate limited: {e.message}")]
|
||||
except ServerError as e:
|
||||
logger.error(f"Server error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Server error: {e.message}")]
|
||||
except OpenRAGError as e:
|
||||
logger.error(f"OpenRAG error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Error: {e.message}")]
|
||||
except Exception as e:
|
||||
logger.error(f"Chat error: {e}")
|
||||
return [TextContent(type="text", text=f"Error: {str(e)}")]
|
||||
return [TextContent(type="text", text="".join(output_parts))]
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.error(f"Authentication error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Authentication error: {e.message}")]
|
||||
except ValidationError as e:
|
||||
logger.error(f"Validation error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Invalid request: {e.message}")]
|
||||
except RateLimitError as e:
|
||||
logger.error(f"Rate limit error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Rate limited: {e.message}")]
|
||||
except ServerError as e:
|
||||
logger.error(f"Server error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Server error: {e.message}")]
|
||||
except OpenRAGError as e:
|
||||
logger.error(f"OpenRAG error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Error: {e.message}")]
|
||||
except Exception as e:
|
||||
logger.error(f"Chat error: {e}")
|
||||
return [TextContent(type="text", text=f"Error: {str(e)}")]
|
||||
|
|
|
|||
|
|
@ -1,257 +1,357 @@
|
|||
# """Document tools for OpenRAG MCP server."""
|
||||
"""Document tools for OpenRAG MCP server."""
|
||||
|
||||
# import logging
|
||||
# from pathlib import Path
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# from mcp.server import Server
|
||||
# from mcp.types import TextContent, Tool
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
# from openrag_sdk import (
|
||||
# AuthenticationError,
|
||||
# NotFoundError,
|
||||
# OpenRAGError,
|
||||
# RateLimitError,
|
||||
# ServerError,
|
||||
# ValidationError,
|
||||
# )
|
||||
from openrag_sdk import (
|
||||
AuthenticationError,
|
||||
NotFoundError,
|
||||
OpenRAGError,
|
||||
RateLimitError,
|
||||
ServerError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
# from openrag_mcp.config import get_client, get_openrag_client
|
||||
from openrag_mcp.config import get_openrag_client
|
||||
|
||||
# logger = logging.getLogger("openrag-mcp.documents")
|
||||
logger = logging.getLogger("openrag-mcp.documents")
|
||||
|
||||
|
||||
# def register_document_tools(server: Server) -> None:
|
||||
# """Register document-related tools with the MCP server."""
|
||||
|
||||
# @server.list_tools()
|
||||
# async def list_document_tools() -> list[Tool]:
|
||||
# """List document tools."""
|
||||
# return [
|
||||
# Tool(
|
||||
# name="openrag_ingest_file",
|
||||
# description=(
|
||||
# "Ingest a local file into the OpenRAG knowledge base. "
|
||||
# "Supported formats: PDF, DOCX, TXT, MD, HTML, and more."
|
||||
# ),
|
||||
# inputSchema={
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "file_path": {
|
||||
# "type": "string",
|
||||
# "description": "Absolute path to the file to ingest",
|
||||
# },
|
||||
# },
|
||||
# "required": ["file_path"],
|
||||
# },
|
||||
# ),
|
||||
# Tool(
|
||||
# name="openrag_ingest_url",
|
||||
# description=(
|
||||
# "Ingest content from a URL into the OpenRAG knowledge base. "
|
||||
# "The URL content will be fetched, processed, and stored."
|
||||
# ),
|
||||
# inputSchema={
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "url": {
|
||||
# "type": "string",
|
||||
# "description": "The URL to fetch and ingest",
|
||||
# },
|
||||
# },
|
||||
# "required": ["url"],
|
||||
# },
|
||||
# ),
|
||||
# Tool(
|
||||
# name="openrag_list_documents",
|
||||
# description="List documents in the OpenRAG knowledge base.",
|
||||
# inputSchema={
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "limit": {
|
||||
# "type": "integer",
|
||||
# "description": "Maximum number of documents to return (default: 50)",
|
||||
# "default": 50,
|
||||
# },
|
||||
# },
|
||||
# "required": [],
|
||||
# },
|
||||
# ),
|
||||
# Tool(
|
||||
# name="openrag_delete_document",
|
||||
# description="Delete a document from the OpenRAG knowledge base.",
|
||||
# inputSchema={
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "filename": {
|
||||
# "type": "string",
|
||||
# "description": "Name of the file to delete",
|
||||
# },
|
||||
# },
|
||||
# "required": ["filename"],
|
||||
# },
|
||||
# ),
|
||||
# ]
|
||||
|
||||
# @server.call_tool()
|
||||
# async def call_document_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
# """Handle document tool calls."""
|
||||
# if name == "openrag_ingest_file":
|
||||
# return await _ingest_file(arguments)
|
||||
# elif name == "openrag_ingest_url":
|
||||
# return await _ingest_url(arguments)
|
||||
# elif name == "openrag_list_documents":
|
||||
# return await _list_documents(arguments)
|
||||
# elif name == "openrag_delete_document":
|
||||
# return await _delete_document(arguments)
|
||||
# return []
|
||||
def get_document_tools() -> list[Tool]:
    """Build the document tool definitions.

    Returns the tools in this order: ingest file, ingest URL, get task status,
    wait for task, delete document.
    """

    def object_schema(properties: dict, required: list[str]) -> dict:
        # Every document tool takes a flat JSON object; only the fields differ.
        return {"type": "object", "properties": properties, "required": required}

    def text_field(description: str) -> dict:
        # A plain string property with no default.
        return {"type": "string", "description": description}

    ingest_file = Tool(
        name="openrag_ingest_file",
        description=(
            "Ingest a local file into the OpenRAG knowledge base. "
            "Supported formats: PDF, DOCX, TXT, MD, HTML, and more. "
            "By default waits for ingestion to complete. Set wait=false to return immediately."
        ),
        inputSchema=object_schema(
            {
                "file_path": text_field("Absolute path to the file to ingest"),
                "wait": {
                    "type": "boolean",
                    "description": "Wait for ingestion to complete (default: true). Set to false to return immediately with task_id.",
                    "default": True,
                },
            },
            ["file_path"],
        ),
    )

    ingest_url = Tool(
        name="openrag_ingest_url",
        description=(
            "Ingest content from a URL into the OpenRAG knowledge base. "
            "The URL content will be fetched, processed, and stored."
        ),
        inputSchema=object_schema(
            {"url": text_field("The URL to fetch and ingest")},
            ["url"],
        ),
    )

    task_status = Tool(
        name="openrag_get_task_status",
        description=(
            "Check the status of an ingestion task. "
            "Use the task_id returned from openrag_ingest_file when wait=false."
        ),
        inputSchema=object_schema(
            {"task_id": text_field("The task ID to check status for")},
            ["task_id"],
        ),
    )

    wait_for_task = Tool(
        name="openrag_wait_for_task",
        description=(
            "Wait for an ingestion task to complete. "
            "Polls the task status until it completes or fails."
        ),
        inputSchema=object_schema(
            {
                "task_id": text_field("The task ID to wait for"),
                "timeout": {
                    "type": "number",
                    "description": "Maximum seconds to wait (default: 300)",
                    "default": 300,
                },
            },
            ["task_id"],
        ),
    )

    delete_document = Tool(
        name="openrag_delete_document",
        description="Delete a document from the OpenRAG knowledge base.",
        inputSchema=object_schema(
            {"filename": text_field("Name of the file to delete")},
            ["filename"],
        ),
    )

    return [ingest_file, ingest_url, task_status, wait_for_task, delete_document]
|
||||
|
||||
|
||||
# async def _ingest_file(arguments: dict) -> list[TextContent]:
|
||||
# """Ingest a local file into OpenRAG using the SDK."""
|
||||
# file_path = arguments.get("file_path", "")
|
||||
|
||||
# if not file_path:
|
||||
# return [TextContent(type="text", text="Error: file_path is required")]
|
||||
|
||||
# path = Path(file_path)
|
||||
|
||||
# if not path.exists():
|
||||
# return [TextContent(type="text", text=f"Error: File not found: {file_path}")]
|
||||
|
||||
# if not path.is_file():
|
||||
# return [TextContent(type="text", text=f"Error: Path is not a file: {file_path}")]
|
||||
|
||||
# try:
|
||||
# client = get_openrag_client()
|
||||
# # Use wait=False to return immediately with task_id
|
||||
# response = await client.documents.ingest(file_path=path, wait=False)
|
||||
|
||||
# result = f"Successfully queued '{response.filename or path.name}' for ingestion."
|
||||
# if response.task_id:
|
||||
# result += f"\nTask ID: {response.task_id}"
|
||||
|
||||
# return [TextContent(type="text", text=result)]
|
||||
|
||||
# except AuthenticationError as e:
|
||||
# logger.error(f"Authentication error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Authentication error: {e.message}")]
|
||||
# except ValidationError as e:
|
||||
# logger.error(f"Validation error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Invalid request: {e.message}")]
|
||||
# except RateLimitError as e:
|
||||
# logger.error(f"Rate limit error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Rate limited: {e.message}")]
|
||||
# except ServerError as e:
|
||||
# logger.error(f"Server error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Server error: {e.message}")]
|
||||
# except OpenRAGError as e:
|
||||
# logger.error(f"OpenRAG error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Error: {e.message}")]
|
||||
# except Exception as e:
|
||||
# logger.error(f"Ingest file error: {e}")
|
||||
# return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")]
|
||||
async def handle_document_tool(name: str, arguments: dict) -> list[TextContent] | None:
    """Run the document tool called ``name`` with ``arguments``.

    Returns the tool implementation's result, or ``None`` when ``name`` is not
    one of the document tools.
    """
    implementations = {
        "openrag_ingest_file": _ingest_file,
        "openrag_ingest_url": _ingest_url,
        "openrag_get_task_status": _get_task_status,
        "openrag_wait_for_task": _wait_for_task,
        "openrag_delete_document": _delete_document,
    }

    implementation = implementations.get(name)
    if implementation is None:
        # Not a document tool.
        return None
    return await implementation(arguments)
|
||||
|
||||
|
||||
# async def _ingest_url(arguments: dict) -> list[TextContent]:
|
||||
# """Ingest content from a URL into OpenRAG.
|
||||
async def _ingest_file(arguments: dict) -> list[TextContent]:
|
||||
"""Ingest a local file into OpenRAG using the SDK."""
|
||||
file_path = arguments.get("file_path", "")
|
||||
wait = arguments.get("wait", True)
|
||||
|
||||
# Note: This uses the SDK's chat to trigger URL ingestion via the agent.
|
||||
# """
|
||||
# url = arguments.get("url", "")
|
||||
if not file_path:
|
||||
return [TextContent(type="text", text="Error: file_path is required")]
|
||||
|
||||
# if not url:
|
||||
# return [TextContent(type="text", text="Error: url is required")]
|
||||
path = Path(file_path)
|
||||
|
||||
# if not url.startswith(("http://", "https://")):
|
||||
# return [TextContent(type="text", text="Error: url must start with http:// or https://")]
|
||||
if not path.exists():
|
||||
return [TextContent(type="text", text=f"Error: File not found: {file_path}")]
|
||||
|
||||
# try:
|
||||
# # Use chat with a special prompt to trigger URL ingestion via the agent
|
||||
# client = get_openrag_client()
|
||||
# response = await client.chat.create(
|
||||
# message=f"Please ingest the content from this URL into the knowledge base: {url}",
|
||||
# )
|
||||
if not path.is_file():
|
||||
return [TextContent(type="text", text=f"Error: Path is not a file: {file_path}")]
|
||||
|
||||
# return [TextContent(type="text", text=f"URL ingestion requested.\n\n{response.response}")]
|
||||
try:
|
||||
client = get_openrag_client()
|
||||
response = await client.documents.ingest(file_path=path, wait=wait)
|
||||
|
||||
# except AuthenticationError as e:
|
||||
# logger.error(f"Authentication error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Authentication error: {e.message}")]
|
||||
# except ServerError as e:
|
||||
# logger.error(f"Server error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Server error: {e.message}")]
|
||||
# except OpenRAGError as e:
|
||||
# logger.error(f"OpenRAG error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Error: {e.message}")]
|
||||
# except Exception as e:
|
||||
# logger.error(f"Ingest URL error: {e}")
|
||||
# return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")]
|
||||
if wait:
|
||||
# Response is IngestTaskStatus when wait=True
|
||||
status = response.status
|
||||
successful = response.successful_files
|
||||
failed = response.failed_files
|
||||
|
||||
if status == "completed":
|
||||
result = f"Successfully ingested '{path.name}'."
|
||||
result += f"\nStatus: {status}"
|
||||
result += f"\nSuccessful files: {successful}"
|
||||
if failed > 0:
|
||||
result += f"\nFailed files: {failed}"
|
||||
else:
|
||||
result = f"Ingestion finished with status: {status}"
|
||||
result += f"\nSuccessful files: {successful}"
|
||||
result += f"\nFailed files: {failed}"
|
||||
else:
|
||||
# Response is IngestResponse when wait=False
|
||||
result = f"Successfully queued '{response.filename or path.name}' for ingestion."
|
||||
if response.task_id:
|
||||
result += f"\nTask ID: {response.task_id}"
|
||||
result += "\n\nUse openrag_get_task_status or openrag_wait_for_task to check progress."
|
||||
|
||||
return [TextContent(type="text", text=result)]
|
||||
|
||||
except AuthenticationError as e:
|
||||
logger.error(f"Authentication error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Authentication error: {e.message}")]
|
||||
except ValidationError as e:
|
||||
logger.error(f"Validation error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Invalid request: {e.message}")]
|
||||
except RateLimitError as e:
|
||||
logger.error(f"Rate limit error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Rate limited: {e.message}")]
|
||||
except ServerError as e:
|
||||
logger.error(f"Server error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Server error: {e.message}")]
|
||||
except OpenRAGError as e:
|
||||
logger.error(f"OpenRAG error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Error: {e.message}")]
|
||||
except TimeoutError as e:
|
||||
logger.error(f"Timeout error: {e}")
|
||||
return [TextContent(type="text", text=f"Ingestion timed out: {str(e)}")]
|
||||
except Exception as e:
|
||||
logger.error(f"Ingest file error: {e}")
|
||||
return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")]
|
||||
|
||||
|
||||
# async def _list_documents(arguments: dict) -> list[TextContent]:
|
||||
# """List documents in the knowledge base.
|
||||
async def _ingest_url(arguments: dict) -> list[TextContent]:
|
||||
"""Ingest content from a URL into OpenRAG.
|
||||
|
||||
# Note: This uses direct HTTP calls as the SDK doesn't yet support listing documents.
|
||||
# """
|
||||
# limit = arguments.get("limit", 50)
|
||||
Note: This uses the SDK's chat to trigger URL ingestion via the agent.
|
||||
"""
|
||||
url = arguments.get("url", "")
|
||||
|
||||
# try:
|
||||
# async with get_client() as client:
|
||||
# response = await client.get("/api/v1/documents", params={"limit": limit})
|
||||
# response.raise_for_status()
|
||||
# data = response.json()
|
||||
if not url:
|
||||
return [TextContent(type="text", text="Error: url is required")]
|
||||
|
||||
# documents = data.get("documents", [])
|
||||
if not url.startswith(("http://", "https://")):
|
||||
return [TextContent(type="text", text="Error: url must start with http:// or https://")]
|
||||
|
||||
# if not documents:
|
||||
# return [TextContent(type="text", text="No documents found in the knowledge base.")]
|
||||
try:
|
||||
# Use chat with a special prompt to trigger URL ingestion via the agent
|
||||
client = get_openrag_client()
|
||||
response = await client.chat.create(
|
||||
message=f"Please ingest the content from this URL into the knowledge base: {url}",
|
||||
)
|
||||
|
||||
# output_parts = [f"Found {len(documents)} document(s):\n"]
|
||||
return [TextContent(type="text", text=f"URL ingestion requested.\n\n{response.response}")]
|
||||
|
||||
# for doc in documents:
|
||||
# filename = doc.get("filename", "Unknown")
|
||||
# chunks = doc.get("chunk_count", 0)
|
||||
# created = doc.get("created_at", "")
|
||||
|
||||
# output_parts.append(f"\n- **{filename}** ({chunks} chunks)")
|
||||
# if created:
|
||||
# output_parts.append(f" - Added: {created[:10]}")
|
||||
|
||||
# return [TextContent(type="text", text="".join(output_parts))]
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"List documents error: {e}")
|
||||
# return [TextContent(type="text", text=f"Error listing documents: {str(e)}")]
|
||||
except AuthenticationError as e:
|
||||
logger.error(f"Authentication error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Authentication error: {e.message}")]
|
||||
except ServerError as e:
|
||||
logger.error(f"Server error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Server error: {e.message}")]
|
||||
except OpenRAGError as e:
|
||||
logger.error(f"OpenRAG error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Error: {e.message}")]
|
||||
except Exception as e:
|
||||
logger.error(f"Ingest URL error: {e}")
|
||||
return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")]
|
||||
|
||||
|
||||
# async def _delete_document(arguments: dict) -> list[TextContent]:
|
||||
# """Delete a document from the knowledge base using the SDK."""
|
||||
# filename = arguments.get("filename", "")
|
||||
async def _get_task_status(arguments: dict) -> list[TextContent]:
|
||||
"""Get the status of an ingestion task."""
|
||||
task_id = arguments.get("task_id", "")
|
||||
|
||||
# if not filename:
|
||||
# return [TextContent(type="text", text="Error: filename is required")]
|
||||
if not task_id:
|
||||
return [TextContent(type="text", text="Error: task_id is required")]
|
||||
|
||||
# try:
|
||||
# client = get_openrag_client()
|
||||
# response = await client.documents.delete(filename)
|
||||
try:
|
||||
client = get_openrag_client()
|
||||
status = await client.documents.get_task_status(task_id)
|
||||
|
||||
# return [TextContent(
|
||||
# type="text",
|
||||
# text=f"Successfully deleted '{filename}' ({response.deleted_chunks} chunks removed).",
|
||||
# )]
|
||||
output_parts = [f"**Task Status: {status.status}**"]
|
||||
output_parts.append(f"\nTask ID: {status.task_id}")
|
||||
output_parts.append(f"\nTotal files: {status.total_files}")
|
||||
output_parts.append(f"\nProcessed: {status.processed_files}")
|
||||
output_parts.append(f"\nSuccessful: {status.successful_files}")
|
||||
output_parts.append(f"\nFailed: {status.failed_files}")
|
||||
|
||||
# except NotFoundError as e:
|
||||
# logger.error(f"Document not found: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Document not found: {e.message}")]
|
||||
# except AuthenticationError as e:
|
||||
# logger.error(f"Authentication error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Authentication error: {e.message}")]
|
||||
# except ServerError as e:
|
||||
# logger.error(f"Server error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Server error: {e.message}")]
|
||||
# except OpenRAGError as e:
|
||||
# logger.error(f"OpenRAG error: {e.message}")
|
||||
# return [TextContent(type="text", text=f"Error: {e.message}")]
|
||||
# except Exception as e:
|
||||
# logger.error(f"Delete document error: {e}")
|
||||
# return [TextContent(type="text", text=f"Error deleting document: {str(e)}")]
|
||||
if status.files:
|
||||
output_parts.append("\n\n**File Details:**")
|
||||
for filename, file_status in status.files.items():
|
||||
output_parts.append(f"\n- {filename}: {file_status}")
|
||||
|
||||
return [TextContent(type="text", text="".join(output_parts))]
|
||||
|
||||
except NotFoundError as e:
|
||||
logger.error(f"Task not found: {e.message}")
|
||||
return [TextContent(type="text", text=f"Task not found: {e.message}")]
|
||||
except AuthenticationError as e:
|
||||
logger.error(f"Authentication error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Authentication error: {e.message}")]
|
||||
except OpenRAGError as e:
|
||||
logger.error(f"OpenRAG error: {e.message}")
|
||||
return [TextContent(type="text", text=f"Error: {e.message}")]
|
||||
except Exception as e:
|
||||
logger.error(f"Get task status error: {e}")
|
||||
return [TextContent(type="text", text=f"Error getting task status: {str(e)}")]
|
||||
|
||||
|
||||
async def _wait_for_task(arguments: dict) -> list[TextContent]:
    """Wait for the ingestion task in ``arguments["task_id"]`` and summarise it.

    ``arguments["timeout"]`` (default 300) is passed to the client's
    ``wait_for_task`` call. Every failure, including a missing ``task_id``, is
    returned as a single text item rather than raised.
    """

    def as_text(message: str) -> list[TextContent]:
        # All outcomes are a one-element list holding a text item.
        return [TextContent(type="text", text=message)]

    task_id = arguments.get("task_id", "")
    timeout = arguments.get("timeout", 300)

    if not task_id:
        return as_text("Error: task_id is required")

    try:
        # The summary is built inside the try so that any error while reading
        # the status object is reported by the generic handler below.
        status = await get_openrag_client().documents.wait_for_task(task_id, timeout=timeout)

        summary = [
            f"**Task Completed: {status.status}**",
            f"\nTask ID: {status.task_id}",
            f"\nTotal files: {status.total_files}",
            f"\nSuccessful: {status.successful_files}",
            f"\nFailed: {status.failed_files}",
        ]

        per_file = status.files
        if per_file:
            summary.append("\n\n**File Details:**")
            summary.extend(
                f"\n- {filename}: {file_status}"
                for filename, file_status in per_file.items()
            )

        return as_text("".join(summary))

    except TimeoutError as exc:
        logger.error(f"Wait for task timeout: {exc}")
        return as_text(f"Task did not complete within {timeout} seconds.")
    except NotFoundError as exc:
        logger.error(f"Task not found: {exc.message}")
        return as_text(f"Task not found: {exc.message}")
    except AuthenticationError as exc:
        logger.error(f"Authentication error: {exc.message}")
        return as_text(f"Authentication error: {exc.message}")
    except OpenRAGError as exc:
        logger.error(f"OpenRAG error: {exc.message}")
        return as_text(f"Error: {exc.message}")
    except Exception as exc:
        logger.error(f"Wait for task error: {exc}")
        return as_text(f"Error waiting for task: {str(exc)}")
|
||||
|
||||
|
||||
async def _delete_document(arguments: dict) -> list[TextContent]:
    """Delete a document from the knowledge base through the SDK client.

    Args:
        arguments: Tool arguments. ``filename`` is required and is passed to
            the client's ``documents.delete`` call.

    Returns:
        A single-element list holding a text message describing success
        (including the number of chunks removed), failure, or the error that
        occurred. Failures are logged and reported as text rather than raised.
    """

    def _text(message: str) -> list[TextContent]:
        # Every outcome of this tool is reported as one text content item.
        return [TextContent(type="text", text=message)]

    filename = arguments.get("filename", "")
    if not filename:
        return _text("Error: filename is required")

    try:
        client = get_openrag_client()
        response = await client.documents.delete(filename)

        # Guard clause: report the unsuccessful case first.
        if not response.success:
            return _text(f"Failed to delete '{filename}'.")

        return _text(
            f"Successfully deleted '{filename}' ({response.deleted_chunks} chunks removed)."
        )

    # Specific handlers come before the generic ones so each failure gets
    # its own user-facing message.
    except NotFoundError as e:
        logger.error(f"Document not found: {e.message}")
        return _text(f"Document not found: {e.message}")
    except AuthenticationError as e:
        logger.error(f"Authentication error: {e.message}")
        return _text(f"Authentication error: {e.message}")
    except ServerError as e:
        logger.error(f"Server error: {e.message}")
        return _text(f"Server error: {e.message}")
    except OpenRAGError as e:
        logger.error(f"OpenRAG error: {e.message}")
        return _text(f"Error: {e.message}")
    except Exception as e:
        logger.error(f"Delete document error: {e}")
        return _text(f"Error deleting document: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -2,95 +2,139 @@
|
|||
|
||||
import logging
|
||||
|
||||
from mcp.server import Server
|
||||
from mcp.types import TextContent, Tool
|
||||
|
||||
from openrag_mcp.config import get_client
|
||||
from openrag_sdk import (
|
||||
AuthenticationError,
|
||||
OpenRAGError,
|
||||
RateLimitError,
|
||||
SearchFilters,
|
||||
ServerError,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
from openrag_mcp.config import get_openrag_client
|
||||
|
||||
logger = logging.getLogger("openrag-mcp.search")
|
||||
|
||||
|
||||
def register_search_tools(server: Server) -> None:
|
||||
"""Register search-related tools with the MCP server."""
|
||||
|
||||
@server.list_tools()
|
||||
async def list_search_tools() -> list[Tool]:
|
||||
"""List search tools."""
|
||||
return [
|
||||
Tool(
|
||||
name="openrag_search",
|
||||
description=(
|
||||
"Search the OpenRAG knowledge base using semantic search. "
|
||||
"Returns matching document chunks with relevance scores."
|
||||
),
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "The search query",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results (default: 10)",
|
||||
"default": 10,
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
def get_search_tools() -> list[Tool]:
    """Return the search-related tool definitions.

    Returns:
        A list with one ``Tool``, ``openrag_search``. Its input schema
        requires ``query`` and optionally accepts ``limit``,
        ``score_threshold``, ``filter_id``, ``data_sources`` and
        ``document_types``.
    """
    return [
        Tool(
            name="openrag_search",
            description=(
                "Search the OpenRAG knowledge base using semantic search. "
                "Returns matching document chunks with relevance scores. "
                "Optionally filter by data sources or document types."
            ),
            # JSON Schema describing the arguments accepted by the tool.
            inputSchema={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query",
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Maximum number of results (default: 10)",
                        "default": 10,
                    },
                    "score_threshold": {
                        "type": "number",
                        "description": "Minimum relevance score threshold (default: 0)",
                        "default": 0,
                    },
                    "filter_id": {
                        "type": "string",
                        "description": "Optional knowledge filter ID to apply",
                    },
                    "data_sources": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Optional list of filenames to filter by",
                    },
                    "document_types": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Optional list of MIME types to filter by (e.g., 'application/pdf')",
                    },
                },
                "required": ["query"],
            },
        ),
    ]
|
||||
|
||||
@server.call_tool()
|
||||
async def call_search_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||
"""Handle search tool calls."""
|
||||
if name != "openrag_search":
|
||||
return []
|
||||
|
||||
query = arguments.get("query", "")
|
||||
limit = arguments.get("limit", 10)
|
||||
async def handle_search_tool(name: str, arguments: dict) -> list[TextContent] | None:
    """Handle calls to the ``openrag_search`` tool.

    Args:
        name: Name of the tool being invoked.
        arguments: Tool arguments. ``query`` is required; ``limit``
            (default 10), ``score_threshold`` (default 0), ``filter_id``,
            ``data_sources`` and ``document_types`` are optional.

    Returns:
        ``None`` when ``name`` is not ``openrag_search``, so the caller can
        try another handler. Otherwise a single-element list holding the
        formatted results, a "no results" notice, or an error message.
        Failures are logged and reported as text rather than raised.
    """
    if name != "openrag_search":
        return None

    query = arguments.get("query", "")
    limit = arguments.get("limit", 10)
    score_threshold = arguments.get("score_threshold", 0)
    filter_id = arguments.get("filter_id")
    data_sources = arguments.get("data_sources")
    document_types = arguments.get("document_types")

    if not query:
        return [TextContent(type="text", text="Error: query is required")]

    try:
        client = get_openrag_client()

        # Build a filters object only when at least one filter list was given.
        filters = None
        if data_sources or document_types:
            filters = SearchFilters(
                data_sources=data_sources,
                document_types=document_types,
            )

        response = await client.search.query(
            query=query,
            limit=limit,
            score_threshold=score_threshold,
            filter_id=filter_id,
            filters=filters,
        )

        if not response.results:
            return [TextContent(type="text", text="No results found.")]

        # Format results
        output_parts = [f"Found {len(response.results)} result(s):\n"]

        for i, result in enumerate(response.results, 1):
            output_parts.append(f"\n---\n**{i}. {result.filename}**")
            if result.page:
                output_parts.append(f" (page {result.page})")
            output_parts.append(f"\nRelevance: {result.score:.2f}\n")

            # Truncate long content so one chunk cannot dominate the reply.
            content = result.text
            if len(content) > 500:
                content = content[:500] + "..."
            output_parts.append(f"\n{content}\n")

        return [TextContent(type="text", text="".join(output_parts))]

    # Specific handlers come before the generic ones so each failure gets
    # its own user-facing message.
    except AuthenticationError as e:
        logger.error(f"Authentication error: {e.message}")
        return [TextContent(type="text", text=f"Authentication error: {e.message}")]
    except ValidationError as e:
        logger.error(f"Validation error: {e.message}")
        return [TextContent(type="text", text=f"Invalid request: {e.message}")]
    except RateLimitError as e:
        logger.error(f"Rate limit error: {e.message}")
        return [TextContent(type="text", text=f"Rate limited: {e.message}")]
    except ServerError as e:
        logger.error(f"Server error: {e.message}")
        return [TextContent(type="text", text=f"Server error: {e.message}")]
    except OpenRAGError as e:
        logger.error(f"OpenRAG error: {e.message}")
        return [TextContent(type="text", text=f"Error: {e.message}")]
    except Exception as e:
        logger.error(f"Search error: {e}")
        return [TextContent(type="text", text=f"Error: {str(e)}")]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue