#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""
WebSocket SDK API for RAGFlow Streaming Responses

This module provides WebSocket endpoints following the SDK API pattern,
mirroring the structure of session.py for consistency.
"""

import logging
import json
from quart import websocket
from api.db.services.dialog_service import DialogService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.conversation_service import async_completion as rag_completion
from api.db.services.canvas_service import completion as agent_completion
from api.utils.api_utils import ws_token_required
from common.constants import StatusEnum

# Request keys consumed by the endpoints themselves. All other fields of a
# client message are forwarded to the completion service as keyword arguments.
_CONTROL_KEYS = frozenset({"question", "session_id", "stream"})


async def send_ws_error(error_message, code=500):
    """Send error message to WebSocket client.

    Args:
        error_message: Human-readable error text; also embedded in
            ``data.answer`` prefixed with ``**ERROR**:``.
        code: Numeric status code placed in the ``code`` field (default 500).
    """
    error_response = {
        "code": code,
        "message": error_message,
        "data": {
            "answer": f"**ERROR**: {error_message}",
            "reference": []
        }
    }
    await websocket.send(json.dumps(error_response, ensure_ascii=False))


async def send_ws_message(data, code=0, message=""):
    """Send message to WebSocket client.

    Args:
        data: JSON-serializable payload placed in the ``data`` field.
        code: Numeric status code (default 0).
        message: Optional message text (default empty).
    """
    response = {
        "code": code,
        "message": message,
        "data": data
    }
    await websocket.send(json.dumps(response, ensure_ascii=False))


async def _receive_request():
    """Receive one client message and decode it as a JSON object.

    Returns:
        The decoded ``dict``, or ``None`` after a 400 error has already been
        sent to the client (the caller should wait for the next message).
    """
    message = await websocket.receive()
    try:
        req = json.loads(message)
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError (undecodable bytes
        # frames) are both ValueError subclasses.
        await send_ws_error(f"Invalid JSON format: {str(e)}", code=400)
        return None
    if not isinstance(req, dict):
        # Valid JSON that is not an object (list, string, number...) would
        # otherwise fail on req.get() and abort the whole connection.
        await send_ws_error("Invalid request: expected a JSON object", code=400)
        return None
    return req


def _extra_kwargs(req, explicit):
    """Return the request fields to forward as extra keyword arguments.

    Drops the control keys handled by the endpoint and every name in
    *explicit* (arguments the endpoint passes itself). This keeps a client
    field such as ``tenant_id`` from colliding with the explicit argument and
    raising a "multiple values for keyword argument" TypeError.
    """
    skip = _CONTROL_KEYS | frozenset(explicit)
    return {k: v for k, v in req.items() if k not in skip}


async def _close_quietly(code):
    """Close the WebSocket with *code*, ignoring errors (e.g. already closed)."""
    try:
        await websocket.close(code)
    except Exception:
        logging.debug("WebSocket close(%s) failed; connection may already be closed", code)


@manager.websocket("/ws/chats/<chat_id>/completions")  # noqa: F821
@ws_token_required
async def chat_completions_ws(tenant_id, chat_id):
    """
    WebSocket endpoint for streaming chat completions.

    Follows the same pattern as the HTTP POST /chats/<chat_id>/completions endpoint.
    Uses /ws/ prefix to avoid routing conflicts with HTTP endpoints.

    Each client message must be a JSON object with ``question``, optional
    ``session_id`` and optional ``stream`` (default true). Any other fields
    are forwarded to the chat completion service.
    """
    # Verify chat ownership
    if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
        await send_ws_error(f"You don't own the chat {chat_id}", code=404)
        await websocket.close(1008)
        return

    logging.info("WebSocket chat connection established for chat_id: %s, tenant: %s", chat_id, tenant_id)

    try:
        while True:
            req = await _receive_request()
            if req is None:
                continue

            question = req.get("question", "")
            session_id = req.get("session_id")
            stream = req.get("stream", True)

            # Only an explicit JSON null is rejected; an empty string is passed through.
            if question is None:
                await send_ws_error("Missing required parameter: question", code=400)
                continue

            extra = _extra_kwargs(req, {"tenant_id", "chat_id"})
            try:
                if stream:
                    async for chunk in rag_completion(
                        tenant_id=tenant_id,
                        chat_id=chat_id,
                        question=question,
                        session_id=session_id,
                        stream=True,
                        **extra,
                    ):
                        # Chunks are "data:<json>" lines; forward only the JSON payload.
                        if not (isinstance(chunk, str) and chunk.startswith("data:")):
                            continue
                        try:
                            payload = json.loads(chunk[5:].strip())
                        except json.JSONDecodeError:
                            continue
                        await websocket.send(json.dumps(payload, ensure_ascii=False))
                    logging.info("Chat completion streamed successfully for chat_id: %s", chat_id)
                else:
                    response = None
                    async for resp in rag_completion(
                        tenant_id=tenant_id,
                        chat_id=chat_id,
                        question=question,
                        session_id=session_id,
                        stream=False,
                        **extra,
                    ):
                        # Non-streaming mode: only the first yielded item is used.
                        response = resp
                        break

                    if response:
                        await send_ws_message(response)
                    else:
                        await send_ws_error("No response generated", code=500)
            except Exception as e:
                # Per-request failure: report it and keep the connection open.
                logging.exception("Error during chat completion: %s", e)
                await send_ws_error(str(e))
    except Exception as e:
        logging.exception("WebSocket error: %s", e)
        try:
            await send_ws_error(str(e))
        except Exception:
            # Best effort: the socket may already be unusable.
            pass
        await _close_quietly(1011)
    finally:
        logging.info("WebSocket chat connection closed for chat_id: %s", chat_id)


@manager.websocket("/ws/agents/<agent_id>/completions")  # noqa: F821
@ws_token_required
async def agent_completions_ws(tenant_id, agent_id):
    """
    WebSocket endpoint for streaming agent completions.

    Follows the same pattern as the HTTP POST /agents/<agent_id>/completions endpoint.
    Uses /ws/ prefix to avoid routing conflicts with HTTP endpoints.

    Each client message must be a JSON object with a non-empty ``question``,
    optional ``session_id`` and optional ``stream`` (default true). Any other
    fields are forwarded to the agent completion service.
    """
    # Verify agent ownership
    if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
        await send_ws_error(f"You don't own the agent {agent_id}", code=404)
        await websocket.close(1008)
        return

    logging.info("WebSocket agent connection established for agent_id: %s, tenant: %s", agent_id, tenant_id)

    try:
        while True:
            req = await _receive_request()
            if req is None:
                continue

            question = req.get("question", "")
            session_id = req.get("session_id")
            stream = req.get("stream", True)

            if not question:
                await send_ws_error("Missing required parameter: question", code=400)
                continue

            extra = _extra_kwargs(req, {"tenant_id", "agent_id"})
            try:
                if stream:
                    async for chunk in agent_completion(
                        tenant_id=tenant_id,
                        agent_id=agent_id,
                        question=question,
                        session_id=session_id,
                        stream=True,
                        **extra,
                    ):
                        if not (isinstance(chunk, str) and chunk.startswith("data:")):
                            continue
                        try:
                            payload = json.loads(chunk[5:].strip())
                        except json.JSONDecodeError:
                            continue
                        # Only "message" / "message_end" events are forwarded to the client.
                        if isinstance(payload, dict) and payload.get("event") in ("message", "message_end"):
                            await send_ws_message(payload)
                    # Final frame with data=True marks the end of this stream.
                    await send_ws_message(True)
                    logging.info("Agent completion streamed successfully for agent_id: %s", agent_id)
                else:
                    # Aggregate every "message" event into one answer; the last
                    # parsed event is used as the envelope for the result.
                    full_content = ""
                    reference = {}
                    final_ans = None
                    async for chunk in agent_completion(
                        tenant_id=tenant_id,
                        agent_id=agent_id,
                        question=question,
                        session_id=session_id,
                        stream=False,
                        **extra,
                    ):
                        if not (isinstance(chunk, str) and chunk.startswith("data:")):
                            continue
                        try:
                            ans = json.loads(chunk[5:])
                            if ans["event"] == "message":
                                full_content += ans["data"]["content"]
                            if ans.get("data", {}).get("reference", None):
                                reference.update(ans["data"]["reference"])
                            final_ans = ans
                        except Exception as e:
                            await send_ws_error(str(e))
                            continue

                    if final_ans:
                        final_ans["data"]["content"] = full_content
                        final_ans["data"]["reference"] = reference
                        await send_ws_message(final_ans)
                    else:
                        await send_ws_error("No response generated", code=500)
            except Exception as e:
                # Per-request failure: report it and keep the connection open.
                logging.exception("Error during agent completion: %s", e)
                await send_ws_error(str(e))
    except Exception as e:
        logging.exception("WebSocket error: %s", e)
        try:
            await send_ws_error(str(e))
        except Exception:
            # Best effort: the socket may already be unusable.
            pass
        await _close_quietly(1011)
    finally:
        logging.info("WebSocket agent connection closed for agent_id: %s", agent_id)