diff --git a/rag/utils/mcp_tool_call_conn.py b/rag/utils/mcp_tool_call_conn.py index 93d678ae6..2093f7bc8 100644 --- a/rag/utils/mcp_tool_call_conn.py +++ b/rag/utils/mcp_tool_call_conn.py @@ -61,7 +61,8 @@ class MCPToolCallSession(ToolCallSession): for h, v in raw_headers.items(): nh = Template(h).safe_substitute(self._server_variables) nv = Template(v).safe_substitute(self._server_variables) - headers[nh] = nv + if nh.strip() and nv.strip().removeprefix("Bearer").strip(): + headers[nh] = nv if self._mcp_server.server_type == MCPServerType.SSE: # SSE transport @@ -76,6 +77,9 @@ class MCPToolCallSession(ToolCallSession): msg = f"Timeout initializing client_session for server {self._mcp_server.id}" logging.error(msg) await self._process_mcp_tasks(None, msg) + except asyncio.CancelledError: + logging.warning(f"SSE transport MCP session cancelled for server {self._mcp_server.id}") + return except Exception: msg = "Connection failed (possibly due to auth error). Please check authentication settings first" await self._process_mcp_tasks(None, msg) @@ -93,6 +97,9 @@ class MCPToolCallSession(ToolCallSession): msg = f"Timeout initializing client_session for server {self._mcp_server.id}" logging.error(msg) await self._process_mcp_tasks(None, msg) + except asyncio.CancelledError: + logging.warning(f"STREAMABLE_HTTP MCP session cancelled for server {self._mcp_server.id}") + return except Exception as e: logging.exception(e) msg = "Connection failed (possibly due to auth error). 
Please check authentication settings first" @@ -107,6 +114,8 @@ class MCPToolCallSession(ToolCallSession): mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1) except asyncio.TimeoutError: continue + except asyncio.CancelledError: + break logging.debug(f"Got MCP task {mcp_task} arguments {arguments}") @@ -114,7 +123,10 @@ class MCPToolCallSession(ToolCallSession): if not client_session or error_message: r = ValueError(error_message) - await result_queue.put(r) + try: + await result_queue.put(r) + except asyncio.CancelledError: + break continue try: @@ -126,10 +138,18 @@ class MCPToolCallSession(ToolCallSession): r = ValueError(f"Unknown MCP task {mcp_task}") except Exception as e: r = e + except asyncio.CancelledError: + break - await result_queue.put(r) + try: + await result_queue.put(r) + except asyncio.CancelledError: + break async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float | int = 8, **kwargs) -> Any: + if self._close: + raise ValueError("Session is closed") + results = asyncio.Queue() await self._queue.put((task_type, kwargs, results)) @@ -163,6 +183,9 @@ class MCPToolCallSession(ToolCallSession): raise def get_tools(self, timeout: float | int = 10) -> list[Tool]: + if self._close: + raise ValueError("Session is closed") + future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(timeout=timeout), self._event_loop) try: return future.result(timeout=timeout) @@ -176,6 +199,9 @@ class MCPToolCallSession(ToolCallSession): @override def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: + if self._close: + return "Error: Session is closed" + future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop) try: return future.result(timeout=timeout) @@ -191,8 +217,29 @@ class MCPToolCallSession(ToolCallSession): return self._close = True - self._event_loop.call_soon_threadsafe(self._event_loop.stop) - 
self._thread_pool.shutdown(wait=True) + + while not self._queue.empty(): + try: + _, _, result_queue = self._queue.get_nowait() + try: + await result_queue.put(asyncio.CancelledError("Session is closing")) + except Exception: + pass + except asyncio.QueueEmpty: + break + except Exception: + break + + try: + self._event_loop.call_soon_threadsafe(self._event_loop.stop) + except Exception: + pass + + try: + self._thread_pool.shutdown(wait=True) + except Exception: + pass + self.__class__._ALL_INSTANCES.discard(self) def close_sync(self, timeout: float | int = 5) -> None: @@ -200,13 +247,16 @@ class MCPToolCallSession(ToolCallSession): logging.warning(f"Event loop already stopped for {self._mcp_server.id}") return - future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop) try: - future.result(timeout=timeout) - except FuturesTimeoutError: - logging.error(f"Timeout while closing session for server {self._mcp_server.id} (timeout={timeout})") + future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop) + try: + future.result(timeout=timeout) + except FuturesTimeoutError: + logging.error(f"Timeout while closing session for server {self._mcp_server.id} (timeout={timeout})") + except Exception: + logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}") except Exception: - logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}") + logging.exception(f"Exception while scheduling close for server {self._mcp_server.id}") def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None: @@ -215,16 +265,24 @@ def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> async def _gather_and_stop() -> None: try: await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True) + except Exception: + logging.exception("Exception during MCP session cleanup") finally: - loop.call_soon_threadsafe(loop.stop) + try: + 
loop.call_soon_threadsafe(loop.stop) + except Exception: + pass - loop = asyncio.new_event_loop() - thread = threading.Thread(target=loop.run_forever, daemon=True) - thread.start() + try: + loop = asyncio.new_event_loop() + thread = threading.Thread(target=loop.run_forever, daemon=True) + thread.start() - asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result() + asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result() + thread.join() + except Exception: + logging.exception("Exception during MCP session cleanup thread management") - thread.join() logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")