Add Chat mcp using sdk

This commit is contained in:
Edwin Jose 2025-12-26 13:43:06 -05:00
parent bf716b49c8
commit cf8750d906
9 changed files with 1650 additions and 1560 deletions

View file

@ -35,7 +35,8 @@ dependencies = [
"docling-serve==1.5.0", "docling-serve==1.5.0",
"docling-core==2.48.1", "docling-core==2.48.1",
"easyocr>=1.7.1; sys_platform != 'darwin'", "easyocr>=1.7.1; sys_platform != 'darwin'",
"zxcvbn>=4.5.0" "zxcvbn>=4.5.0",
"openrag-sdk>=0.1.0",
] ]
[dependency-groups] [dependency-groups]

View file

@ -19,6 +19,7 @@ classifiers = [
dependencies = [ dependencies = [
"mcp>=1.0.0", "mcp>=1.0.0",
"httpx>=0.27.0", "httpx>=0.27.0",
"openrag-sdk>=0.1.0",
] ]
[project.scripts] [project.scripts]

View file

@ -2,6 +2,8 @@
import os import os
from openrag_sdk import OpenRAGClient
class Config: class Config:
"""Configuration loaded from environment variables.""" """Configuration loaded from environment variables."""
@ -26,6 +28,7 @@ class Config:
_config: Config | None = None _config: Config | None = None
_openrag_client: OpenRAGClient | None = None
def get_config() -> Config: def get_config() -> Config:
@ -36,8 +39,21 @@ def get_config() -> Config:
return _config return _config
def get_openrag_client() -> OpenRAGClient:
"""Get singleton OpenRAGClient instance."""
global _openrag_client
if _openrag_client is None:
# OpenRAGClient reads OPENRAG_API_KEY and OPENRAG_URL from env
_openrag_client = OpenRAGClient()
return _openrag_client
def get_client(): def get_client():
"""Get an httpx async client configured for OpenRAG.""" """Get an httpx async client configured for OpenRAG.
This is kept for backward compatibility with operations
not yet supported by the SDK (list_documents, ingest_url).
"""
import httpx import httpx
config = get_config() config = get_config()
@ -46,4 +62,3 @@ def get_client():
headers=config.headers, headers=config.headers,
timeout=60.0, timeout=60.0,
) )

View file

@ -8,8 +8,8 @@ from mcp.server.stdio import stdio_server
from openrag_mcp.config import get_config from openrag_mcp.config import get_config
from openrag_mcp.tools.chat import register_chat_tools from openrag_mcp.tools.chat import register_chat_tools
from openrag_mcp.tools.search import register_search_tools # from openrag_mcp.tools.search import register_search_tools
from openrag_mcp.tools.documents import register_document_tools # from openrag_mcp.tools.documents import register_document_tools
# Configure logging to stderr (stdout is used for MCP protocol) # Configure logging to stderr (stdout is used for MCP protocol)
logging.basicConfig( logging.basicConfig(
@ -31,8 +31,8 @@ def create_server() -> Server:
# Register all tools # Register all tools
register_chat_tools(server) register_chat_tools(server)
register_search_tools(server) # register_search_tools(server)
register_document_tools(server) # register_document_tools(server)
logger.info("OpenRAG MCP server initialized with all tools") logger.info("OpenRAG MCP server initialized with all tools")
return server return server

View file

@ -1,8 +1,8 @@
"""OpenRAG MCP tools.""" """OpenRAG MCP tools."""
from openrag_mcp.tools.chat import register_chat_tools from openrag_mcp.tools.chat import register_chat_tools
from openrag_mcp.tools.search import register_search_tools # from openrag_mcp.tools.search import register_search_tools
from openrag_mcp.tools.documents import register_document_tools # from openrag_mcp.tools.documents import register_document_tools
__all__ = ["register_chat_tools", "register_search_tools", "register_document_tools"] __all__ = ["register_chat_tools"]

View file

@ -1,12 +1,19 @@
"""Chat tool for OpenRAG MCP server.""" """Chat tool for OpenRAG MCP server."""
import json
import logging import logging
from mcp.server import Server from mcp.server import Server
from mcp.types import TextContent, Tool from mcp.types import TextContent, Tool
from openrag_mcp.config import get_client from openrag_sdk import (
AuthenticationError,
OpenRAGError,
RateLimitError,
ServerError,
ValidationError,
)
from openrag_mcp.config import get_openrag_client
logger = logging.getLogger("openrag-mcp.chat") logger = logging.getLogger("openrag-mcp.chat")
@ -23,7 +30,8 @@ def register_chat_tools(server: Server) -> None:
description=( description=(
"Send a message to OpenRAG and get a RAG-enhanced response. " "Send a message to OpenRAG and get a RAG-enhanced response. "
"The response is informed by documents in your knowledge base. " "The response is informed by documents in your knowledge base. "
"Use chat_id to continue a previous conversation." "Use chat_id to continue a previous conversation, or filter_id "
"to apply a knowledge filter."
), ),
inputSchema={ inputSchema={
"type": "object", "type": "object",
@ -36,11 +44,20 @@ def register_chat_tools(server: Server) -> None:
"type": "string", "type": "string",
"description": "Optional conversation ID to continue a previous chat", "description": "Optional conversation ID to continue a previous chat",
}, },
"filter_id": {
"type": "string",
"description": "Optional knowledge filter ID to apply",
},
"limit": { "limit": {
"type": "integer", "type": "integer",
"description": "Maximum number of sources to retrieve (default: 10)", "description": "Maximum number of sources to retrieve (default: 10)",
"default": 10, "default": 10,
}, },
"score_threshold": {
"type": "number",
"description": "Minimum relevance score threshold (default: 0)",
"default": 0,
},
}, },
"required": ["message"], "required": ["message"],
}, },
@ -55,46 +72,51 @@ def register_chat_tools(server: Server) -> None:
message = arguments.get("message", "") message = arguments.get("message", "")
chat_id = arguments.get("chat_id") chat_id = arguments.get("chat_id")
filter_id = arguments.get("filter_id")
limit = arguments.get("limit", 10) limit = arguments.get("limit", 10)
score_threshold = arguments.get("score_threshold", 0)
if not message: if not message:
return [TextContent(type="text", text="Error: message is required")] return [TextContent(type="text", text="Error: message is required")]
try: try:
async with get_client() as client: client = get_openrag_client()
payload = { response = await client.chat.create(
"message": message, message=message,
"stream": False, chat_id=chat_id,
"limit": limit, filter_id=filter_id,
} limit=limit,
if chat_id: score_threshold=score_threshold,
payload["chat_id"] = chat_id )
response = await client.post("/api/v1/chat", json=payload) # Build formatted response
response.raise_for_status() output_parts = [response.response]
data = response.json()
# Format the response if response.sources:
result_text = data.get("response", "") output_parts.append("\n\n---\n**Sources:**")
sources = data.get("sources", []) for i, source in enumerate(response.sources, 1):
new_chat_id = data.get("chat_id") output_parts.append(f"\n{i}. {source.filename} (relevance: {source.score:.2f})")
# Build formatted response if response.chat_id:
output_parts = [result_text] output_parts.append(f"\n\n_Chat ID: {response.chat_id}_")
if sources: return [TextContent(type="text", text="".join(output_parts))]
output_parts.append("\n\n---\n**Sources:**")
for i, source in enumerate(sources, 1):
filename = source.get("filename", "Unknown")
score = source.get("score", 0)
output_parts.append(f"\n{i}. {filename} (relevance: {score:.2f})")
if new_chat_id:
output_parts.append(f"\n\n_Chat ID: {new_chat_id}_")
return [TextContent(type="text", text="".join(output_parts))]
except AuthenticationError as e:
logger.error(f"Authentication error: {e.message}")
return [TextContent(type="text", text=f"Authentication error: {e.message}")]
except ValidationError as e:
logger.error(f"Validation error: {e.message}")
return [TextContent(type="text", text=f"Invalid request: {e.message}")]
except RateLimitError as e:
logger.error(f"Rate limit error: {e.message}")
return [TextContent(type="text", text=f"Rate limited: {e.message}")]
except ServerError as e:
logger.error(f"Server error: {e.message}")
return [TextContent(type="text", text=f"Server error: {e.message}")]
except OpenRAGError as e:
logger.error(f"OpenRAG error: {e.message}")
return [TextContent(type="text", text=f"Error: {e.message}")]
except Exception as e: except Exception as e:
logger.error(f"Chat error: {e}") logger.error(f"Chat error: {e}")
return [TextContent(type="text", text=f"Error: {str(e)}")] return [TextContent(type="text", text=f"Error: {str(e)}")]

View file

@ -1,236 +1,257 @@
"""Document tools for OpenRAG MCP server.""" # """Document tools for OpenRAG MCP server."""
import logging # import logging
import os # from pathlib import Path
from pathlib import Path
from mcp.server import Server # from mcp.server import Server
from mcp.types import TextContent, Tool # from mcp.types import TextContent, Tool
from openrag_mcp.config import get_client # from openrag_sdk import (
# AuthenticationError,
# NotFoundError,
# OpenRAGError,
# RateLimitError,
# ServerError,
# ValidationError,
# )
logger = logging.getLogger("openrag-mcp.documents") # from openrag_mcp.config import get_client, get_openrag_client
# logger = logging.getLogger("openrag-mcp.documents")
def register_document_tools(server: Server) -> None: # def register_document_tools(server: Server) -> None:
"""Register document-related tools with the MCP server.""" # """Register document-related tools with the MCP server."""
@server.list_tools() # @server.list_tools()
async def list_document_tools() -> list[Tool]: # async def list_document_tools() -> list[Tool]:
"""List document tools.""" # """List document tools."""
return [ # return [
Tool( # Tool(
name="openrag_ingest_file", # name="openrag_ingest_file",
description=( # description=(
"Ingest a local file into the OpenRAG knowledge base. " # "Ingest a local file into the OpenRAG knowledge base. "
"Supported formats: PDF, DOCX, TXT, MD, HTML, and more." # "Supported formats: PDF, DOCX, TXT, MD, HTML, and more."
), # ),
inputSchema={ # inputSchema={
"type": "object", # "type": "object",
"properties": { # "properties": {
"file_path": { # "file_path": {
"type": "string", # "type": "string",
"description": "Absolute path to the file to ingest", # "description": "Absolute path to the file to ingest",
}, # },
}, # },
"required": ["file_path"], # "required": ["file_path"],
}, # },
), # ),
Tool( # Tool(
name="openrag_ingest_url", # name="openrag_ingest_url",
description=( # description=(
"Ingest content from a URL into the OpenRAG knowledge base. " # "Ingest content from a URL into the OpenRAG knowledge base. "
"The URL content will be fetched, processed, and stored." # "The URL content will be fetched, processed, and stored."
), # ),
inputSchema={ # inputSchema={
"type": "object", # "type": "object",
"properties": { # "properties": {
"url": { # "url": {
"type": "string", # "type": "string",
"description": "The URL to fetch and ingest", # "description": "The URL to fetch and ingest",
}, # },
}, # },
"required": ["url"], # "required": ["url"],
}, # },
), # ),
Tool( # Tool(
name="openrag_list_documents", # name="openrag_list_documents",
description="List documents in the OpenRAG knowledge base.", # description="List documents in the OpenRAG knowledge base.",
inputSchema={ # inputSchema={
"type": "object", # "type": "object",
"properties": { # "properties": {
"limit": { # "limit": {
"type": "integer", # "type": "integer",
"description": "Maximum number of documents to return (default: 50)", # "description": "Maximum number of documents to return (default: 50)",
"default": 50, # "default": 50,
}, # },
}, # },
"required": [], # "required": [],
}, # },
), # ),
Tool( # Tool(
name="openrag_delete_document", # name="openrag_delete_document",
description="Delete a document from the OpenRAG knowledge base.", # description="Delete a document from the OpenRAG knowledge base.",
inputSchema={ # inputSchema={
"type": "object", # "type": "object",
"properties": { # "properties": {
"filename": { # "filename": {
"type": "string", # "type": "string",
"description": "Name of the file to delete", # "description": "Name of the file to delete",
}, # },
}, # },
"required": ["filename"], # "required": ["filename"],
}, # },
), # ),
] # ]
@server.call_tool() # @server.call_tool()
async def call_document_tool(name: str, arguments: dict) -> list[TextContent]: # async def call_document_tool(name: str, arguments: dict) -> list[TextContent]:
"""Handle document tool calls.""" # """Handle document tool calls."""
if name == "openrag_ingest_file": # if name == "openrag_ingest_file":
return await _ingest_file(arguments) # return await _ingest_file(arguments)
elif name == "openrag_ingest_url": # elif name == "openrag_ingest_url":
return await _ingest_url(arguments) # return await _ingest_url(arguments)
elif name == "openrag_list_documents": # elif name == "openrag_list_documents":
return await _list_documents(arguments) # return await _list_documents(arguments)
elif name == "openrag_delete_document": # elif name == "openrag_delete_document":
return await _delete_document(arguments) # return await _delete_document(arguments)
return [] # return []
async def _ingest_file(arguments: dict) -> list[TextContent]: # async def _ingest_file(arguments: dict) -> list[TextContent]:
"""Ingest a local file into OpenRAG.""" # """Ingest a local file into OpenRAG using the SDK."""
file_path = arguments.get("file_path", "") # file_path = arguments.get("file_path", "")
if not file_path: # if not file_path:
return [TextContent(type="text", text="Error: file_path is required")] # return [TextContent(type="text", text="Error: file_path is required")]
path = Path(file_path) # path = Path(file_path)
if not path.exists(): # if not path.exists():
return [TextContent(type="text", text=f"Error: File not found: {file_path}")] # return [TextContent(type="text", text=f"Error: File not found: {file_path}")]
if not path.is_file(): # if not path.is_file():
return [TextContent(type="text", text=f"Error: Path is not a file: {file_path}")] # return [TextContent(type="text", text=f"Error: Path is not a file: {file_path}")]
try: # try:
async with get_client() as client: # client = get_openrag_client()
# Read file and upload # # Use wait=False to return immediately with task_id
with open(path, "rb") as f: # response = await client.documents.ingest(file_path=path, wait=False)
files = {"file": (path.name, f)}
# Remove Content-Type header for multipart upload
headers = dict(client.headers)
headers.pop("Content-Type", None)
response = await client.post( # result = f"Successfully queued '{response.filename or path.name}' for ingestion."
"/api/v1/documents/ingest", # if response.task_id:
files=files, # result += f"\nTask ID: {response.task_id}"
headers=headers,
)
response.raise_for_status()
data = response.json()
task_id = data.get("task_id") # return [TextContent(type="text", text=result)]
filename = data.get("filename", path.name)
result = f"Successfully queued '{filename}' for ingestion." # except AuthenticationError as e:
if task_id: # logger.error(f"Authentication error: {e.message}")
result += f"\nTask ID: {task_id}" # return [TextContent(type="text", text=f"Authentication error: {e.message}")]
# except ValidationError as e:
return [TextContent(type="text", text=result)] # logger.error(f"Validation error: {e.message}")
# return [TextContent(type="text", text=f"Invalid request: {e.message}")]
except Exception as e: # except RateLimitError as e:
logger.error(f"Ingest file error: {e}") # logger.error(f"Rate limit error: {e.message}")
return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")] # return [TextContent(type="text", text=f"Rate limited: {e.message}")]
# except ServerError as e:
# logger.error(f"Server error: {e.message}")
# return [TextContent(type="text", text=f"Server error: {e.message}")]
# except OpenRAGError as e:
# logger.error(f"OpenRAG error: {e.message}")
# return [TextContent(type="text", text=f"Error: {e.message}")]
# except Exception as e:
# logger.error(f"Ingest file error: {e}")
# return [TextContent(type="text", text=f"Error ingesting file: {str(e)}")]
async def _ingest_url(arguments: dict) -> list[TextContent]: # async def _ingest_url(arguments: dict) -> list[TextContent]:
"""Ingest content from a URL into OpenRAG.""" # """Ingest content from a URL into OpenRAG.
url = arguments.get("url", "")
if not url: # Note: This uses the SDK's chat to trigger URL ingestion via the agent.
return [TextContent(type="text", text="Error: url is required")] # """
# url = arguments.get("url", "")
if not url.startswith(("http://", "https://")): # if not url:
return [TextContent(type="text", text="Error: url must start with http:// or https://")] # return [TextContent(type="text", text="Error: url is required")]
try: # if not url.startswith(("http://", "https://")):
# Use chat with a special prompt to trigger URL ingestion via the agent # return [TextContent(type="text", text="Error: url must start with http:// or https://")]
async with get_client() as client:
payload = {
"message": f"Please ingest the content from this URL into the knowledge base: {url}",
"stream": False,
}
response = await client.post("/api/v1/chat", json=payload) # try:
response.raise_for_status() # # Use chat with a special prompt to trigger URL ingestion via the agent
data = response.json() # client = get_openrag_client()
# response = await client.chat.create(
# message=f"Please ingest the content from this URL into the knowledge base: {url}",
# )
result_text = data.get("response", "") # return [TextContent(type="text", text=f"URL ingestion requested.\n\n{response.response}")]
return [TextContent(type="text", text=f"URL ingestion requested.\n\n{result_text}")]
except Exception as e: # except AuthenticationError as e:
logger.error(f"Ingest URL error: {e}") # logger.error(f"Authentication error: {e.message}")
return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")] # return [TextContent(type="text", text=f"Authentication error: {e.message}")]
# except ServerError as e:
# logger.error(f"Server error: {e.message}")
# return [TextContent(type="text", text=f"Server error: {e.message}")]
# except OpenRAGError as e:
# logger.error(f"OpenRAG error: {e.message}")
# return [TextContent(type="text", text=f"Error: {e.message}")]
# except Exception as e:
# logger.error(f"Ingest URL error: {e}")
# return [TextContent(type="text", text=f"Error ingesting URL: {str(e)}")]
async def _list_documents(arguments: dict) -> list[TextContent]: # async def _list_documents(arguments: dict) -> list[TextContent]:
"""List documents in the knowledge base.""" # """List documents in the knowledge base.
limit = arguments.get("limit", 50)
try: # Note: This uses direct HTTP calls as the SDK doesn't yet support listing documents.
async with get_client() as client: # """
response = await client.get("/api/v1/documents", params={"limit": limit}) # limit = arguments.get("limit", 50)
response.raise_for_status()
data = response.json()
documents = data.get("documents", []) # try:
# async with get_client() as client:
# response = await client.get("/api/v1/documents", params={"limit": limit})
# response.raise_for_status()
# data = response.json()
if not documents: # documents = data.get("documents", [])
return [TextContent(type="text", text="No documents found in the knowledge base.")]
output_parts = [f"Found {len(documents)} document(s):\n"] # if not documents:
# return [TextContent(type="text", text="No documents found in the knowledge base.")]
for doc in documents: # output_parts = [f"Found {len(documents)} document(s):\n"]
filename = doc.get("filename", "Unknown")
chunks = doc.get("chunk_count", 0)
created = doc.get("created_at", "")
output_parts.append(f"\n- **{filename}** ({chunks} chunks)") # for doc in documents:
if created: # filename = doc.get("filename", "Unknown")
output_parts.append(f" - Added: {created[:10]}") # chunks = doc.get("chunk_count", 0)
# created = doc.get("created_at", "")
return [TextContent(type="text", text="".join(output_parts))] # output_parts.append(f"\n- **{filename}** ({chunks} chunks)")
# if created:
# output_parts.append(f" - Added: {created[:10]}")
except Exception as e: # return [TextContent(type="text", text="".join(output_parts))]
logger.error(f"List documents error: {e}")
return [TextContent(type="text", text=f"Error listing documents: {str(e)}")] # except Exception as e:
# logger.error(f"List documents error: {e}")
# return [TextContent(type="text", text=f"Error listing documents: {str(e)}")]
async def _delete_document(arguments: dict) -> list[TextContent]: # async def _delete_document(arguments: dict) -> list[TextContent]:
"""Delete a document from the knowledge base.""" # """Delete a document from the knowledge base using the SDK."""
filename = arguments.get("filename", "") # filename = arguments.get("filename", "")
if not filename: # if not filename:
return [TextContent(type="text", text="Error: filename is required")] # return [TextContent(type="text", text="Error: filename is required")]
try: # try:
async with get_client() as client: # client = get_openrag_client()
response = await client.request( # response = await client.documents.delete(filename)
"DELETE",
"/api/v1/documents",
json={"filename": filename},
)
response.raise_for_status()
data = response.json()
deleted_count = data.get("deleted_count", 0) # return [TextContent(
return [TextContent( # type="text",
type="text", # text=f"Successfully deleted '{filename}' ({response.deleted_chunks} chunks removed).",
text=f"Successfully deleted '{filename}' ({deleted_count} chunks removed).", # )]
)]
except Exception as e:
logger.error(f"Delete document error: {e}")
return [TextContent(type="text", text=f"Error deleting document: {str(e)}")]
# except NotFoundError as e:
# logger.error(f"Document not found: {e.message}")
# return [TextContent(type="text", text=f"Document not found: {e.message}")]
# except AuthenticationError as e:
# logger.error(f"Authentication error: {e.message}")
# return [TextContent(type="text", text=f"Authentication error: {e.message}")]
# except ServerError as e:
# logger.error(f"Server error: {e.message}")
# return [TextContent(type="text", text=f"Server error: {e.message}")]
# except OpenRAGError as e:
# logger.error(f"OpenRAG error: {e.message}")
# return [TextContent(type="text", text=f"Error: {e.message}")]
# except Exception as e:
# logger.error(f"Delete document error: {e}")
# return [TextContent(type="text", text=f"Error deleting document: {str(e)}")]

15
sdks/mcp/uv.lock generated
View file

@ -337,12 +337,27 @@ source = { editable = "." }
dependencies = [ dependencies = [
{ name = "httpx" }, { name = "httpx" },
{ name = "mcp" }, { name = "mcp" },
{ name = "openrag-sdk" },
] ]
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "httpx", specifier = ">=0.27.0" }, { name = "httpx", specifier = ">=0.27.0" },
{ name = "mcp", specifier = ">=1.0.0" }, { name = "mcp", specifier = ">=1.0.0" },
{ name = "openrag-sdk", specifier = ">=0.1.0" },
]
[[package]]
name = "openrag-sdk"
version = "0.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
{ name = "pydantic" },
]
sdist = { url = "https://files.pythonhosted.org/packages/d7/9e/7a10ddb6742417970163d11aed40146bb3340532f17c1099434d8667a4e0/openrag_sdk-0.1.0.tar.gz", hash = "sha256:cf99fddf254c6c72295c73498877c56f68287d33a072ca2171f490de7df8617e", size = 17116 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/74/3c/6962d87b8b4604ef8cf776bf7d4e416a0025d5f08f5400ecc47b35616ae0/openrag_sdk-0.1.0-py3-none-any.whl", hash = "sha256:6b99be26b64a61e2347a4ce8f530b673a119d59bfe11ff5ead6318b9a3fb9dac", size = 15756 },
] ]
[[package]] [[package]]

2663
uv.lock generated

File diff suppressed because it is too large Load diff