116 lines
3.8 KiB
Python
116 lines
3.8 KiB
Python
"""Chat tool for OpenRAG MCP server."""
|
|
|
|
import logging
|
|
|
|
from mcp.types import TextContent, Tool
|
|
|
|
from openrag_sdk import (
|
|
AuthenticationError,
|
|
OpenRAGError,
|
|
RateLimitError,
|
|
ServerError,
|
|
ValidationError,
|
|
)
|
|
|
|
from openrag_mcp.config import get_openrag_client
|
|
from openrag_mcp.tools.registry import register_tool
|
|
|
|
logger = logging.getLogger("openrag-mcp.chat")


# Tool definition
#
# Summary shown to MCP clients alongside the tool name.
_CHAT_DESCRIPTION = (
    "Send a message to OpenRAG and get a RAG-enhanced response. "
    "The response is informed by documents in your knowledge base. "
    "Use chat_id to continue a previous conversation, or filter_id "
    "to apply a knowledge filter."
)

# JSON Schema for the tool's arguments; only "message" is mandatory, the
# numeric options advertise their defaults to the client.
_CHAT_INPUT_SCHEMA = {
    "type": "object",
    "properties": {
        "message": {
            "type": "string",
            "description": "Your question or message to send to OpenRAG",
        },
        "chat_id": {
            "type": "string",
            "description": "Optional conversation ID to continue a previous chat",
        },
        "filter_id": {
            "type": "string",
            "description": "Optional knowledge filter ID to apply",
        },
        "limit": {
            "type": "integer",
            "description": "Maximum number of sources to retrieve (default: 10)",
            "default": 10,
        },
        "score_threshold": {
            "type": "number",
            "description": "Minimum relevance score threshold (default: 0)",
            "default": 0,
        },
    },
    "required": ["message"],
}

CHAT_TOOL = Tool(
    name="openrag_chat",
    description=_CHAT_DESCRIPTION,
    inputSchema=_CHAT_INPUT_SCHEMA,
)
|
|
|
|
|
|
def _format_chat_response(response) -> str:
    """Render an OpenRAG chat response as the text returned to the MCP client.

    The answer comes first, then (when present) a numbered ``**Sources:**``
    list with each source's filename and relevance score, then the chat ID
    that can be passed back as ``chat_id`` to continue the conversation.
    """
    # NOTE(review): assumes the SDK response exposes ``response``, ``sources``
    # and ``chat_id``, and each source exposes ``filename`` and ``score``.
    parts = [response.response or ""]

    if response.sources:
        parts.append("\n\n---\n**Sources:**")
        for i, source in enumerate(response.sources, 1):
            score = source.score
            # A missing score must not turn a successful answer into an error.
            relevance = "n/a" if score is None else f"{score:.2f}"
            parts.append(f"\n{i}. {source.filename} (relevance: {relevance})")

    if response.chat_id:
        parts.append(f"\n\n_Chat ID: {response.chat_id}_")

    return "".join(parts)


async def handle_chat(arguments: dict) -> list[TextContent]:
    """Handle openrag_chat tool calls.

    Forwards the user's message, plus the optional ``chat_id``,
    ``filter_id``, ``limit`` and ``score_threshold``, to
    ``client.chat.create``. It then formats the answer with its sources
    and chat ID.

    Args:
        arguments: Tool-call arguments. ``message`` is required and must be
            a non-blank string. ``limit`` defaults to 10 and
            ``score_threshold`` to 0 when absent or null.

    Returns:
        A one-element list holding a ``TextContent``. It contains the
        formatted answer on success, or an error description otherwise.
        SDK errors and unexpected exceptions raised while calling OpenRAG or
        formatting its reply are reported as text instead of being raised.
    """
    message = arguments.get("message")
    chat_id = arguments.get("chat_id")
    filter_id = arguments.get("filter_id")

    # ``dict.get(key, default)`` only applies the default when the key is
    # absent. An explicit JSON null arrives as None, so fall back explicitly
    # instead of forwarding None to the SDK.
    limit = arguments.get("limit")
    if limit is None:
        limit = 10
    score_threshold = arguments.get("score_threshold")
    if score_threshold is None:
        score_threshold = 0

    # Reject missing, non-string, empty and whitespace-only messages up front.
    if not isinstance(message, str) or not message.strip():
        return [TextContent(type="text", text="Error: message is required")]

    try:
        client = get_openrag_client()
        response = await client.chat.create(
            message=message,
            chat_id=chat_id,
            filter_id=filter_id,
            limit=limit,
            score_threshold=score_threshold,
        )
        text = _format_chat_response(response)
    # Specific SDK errors come before the OpenRAGError handler (presumably
    # their common base class) so each gets its own user-facing prefix.
    except AuthenticationError as e:
        logger.error("Authentication error: %s", e.message)
        return [TextContent(type="text", text=f"Authentication error: {e.message}")]
    except ValidationError as e:
        logger.error("Validation error: %s", e.message)
        return [TextContent(type="text", text=f"Invalid request: {e.message}")]
    except RateLimitError as e:
        logger.error("Rate limit error: %s", e.message)
        return [TextContent(type="text", text=f"Rate limited: {e.message}")]
    except ServerError as e:
        logger.error("Server error: %s", e.message)
        return [TextContent(type="text", text=f"Server error: {e.message}")]
    except OpenRAGError as e:
        logger.error("OpenRAG error: %s", e.message)
        return [TextContent(type="text", text=f"Error: {e.message}")]
    except Exception as e:
        # Unexpected failure: keep the full traceback in the server log.
        logger.exception("Chat error: %s", e)
        return [TextContent(type="text", text=f"Error: {e}")]

    return [TextContent(type="text", text=text)]
|
|
|
|
|
|
# Register the tool.
# This is a module-level call, so it runs once when this module is imported.
# NOTE(review): register_tool presumably maps CHAT_TOOL to handle_chat for
# dispatch; confirm against openrag_mcp.tools.registry.
register_tool(
    CHAT_TOOL,
    handle_chat,
)
|