diff --git a/api/apps/sdk/websocket.py b/api/apps/sdk/websocket.py new file mode 100644 index 000000000..5faf11066 --- /dev/null +++ b/api/apps/sdk/websocket.py @@ -0,0 +1,248 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +WebSocket SDK API for RAGFlow Streaming Responses + +This module provides WebSocket endpoints following the SDK API pattern, +mirroring the structure of session.py for consistency. +""" + +import logging +import json +from quart import websocket + +from api.db.services.dialog_service import DialogService +from api.db.services.canvas_service import UserCanvasService +from api.db.services.conversation_service import completion as rag_completion +from api.db.services.canvas_service import completion as agent_completion +from api.utils.api_utils import ws_token_required +from common.constants import StatusEnum + + +async def send_ws_error(error_message, code=500): + """Send error message to WebSocket client.""" + error_response = { + "code": code, + "message": error_message, + "data": { + "answer": f"**ERROR**: {error_message}", + "reference": [] + } + } + await websocket.send(json.dumps(error_response, ensure_ascii=False)) + + +async def send_ws_message(data, code=0, message=""): + """Send message to WebSocket client.""" + response = { + "code": code, + "message": message, + "data": data + } + await websocket.send(json.dumps(response, ensure_ascii=False)) + + +@manager.websocket("/chats//completions") # noqa: F821 +@ws_token_required +async def chat_completions_ws(tenant_id, chat_id): + """ + WebSocket endpoint for streaming chat completions. + Follows the same pattern as the HTTP POST /chats//completions endpoint. + """ + # Verify chat ownership + if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): + await send_ws_error(f"You don't own the chat {chat_id}", code=404) + await websocket.close(1008) + return + + logging.info(f"WebSocket chat connection established for chat_id: {chat_id}, tenant: {tenant_id}") + + try: + while True: + message = await websocket.receive() + + try: + req = json.loads(message) + except json.JSONDecodeError as e: + await send_ws_error(f"Invalid JSON format: {str(e)}", code=400) + continue + + question = req.get("question", "") + session_id = req.get("session_id") + stream = req.get("stream", True) + + if question is None: + await send_ws_error("Missing required parameter: question", code=400) + continue + + try: + if stream: + for response_chunk in rag_completion( + tenant_id=tenant_id, + chat_id=chat_id, + question=question, + session_id=session_id, + stream=True, + **{k: v for k, v in req.items() if k not in ["question", "session_id", "stream"]} + ): + if response_chunk.startswith("data:"): + json_str = response_chunk[5:].strip() + try: + response_data = json.loads(json_str) + await websocket.send(json.dumps(response_data, ensure_ascii=False)) + except json.JSONDecodeError: + continue + + logging.info(f"Chat completion streamed successfully for chat_id: {chat_id}") + else: + response = None + for resp in rag_completion( + tenant_id=tenant_id, + chat_id=chat_id, + question=question, + session_id=session_id, + stream=False, + **{k: v for k, v in req.items() if k not in ["question", "session_id", "stream"]} + ): + response = resp + break + + if response: + await send_ws_message(response) + else: + await send_ws_error("No response generated", code=500) + + except Exception as e: + logging.exception(f"Error during chat completion: {str(e)}") + await send_ws_error(str(e)) + + except Exception as e: + logging.exception(f"WebSocket error: {str(e)}") + try: + await send_ws_error(str(e)) + except Exception: + pass + await websocket.close(1011) + + finally: + logging.info(f"WebSocket chat connection closed for chat_id: {chat_id}") + + +@manager.websocket("/agents//completions") # noqa: F821 +@ws_token_required +async def agent_completions_ws(tenant_id, agent_id): + """ + WebSocket endpoint for streaming agent completions. + Follows the same pattern as the HTTP POST /agents//completions endpoint. + """ + # Verify agent ownership + if not UserCanvasService.query(user_id=tenant_id, id=agent_id): + await send_ws_error(f"You don't own the agent {agent_id}", code=404) + await websocket.close(1008) + return + + logging.info(f"WebSocket agent connection established for agent_id: {agent_id}, tenant: {tenant_id}") + + try: + while True: + message = await websocket.receive() + + try: + req = json.loads(message) + except json.JSONDecodeError as e: + await send_ws_error(f"Invalid JSON format: {str(e)}", code=400) + continue + + question = req.get("question", "") + session_id = req.get("session_id") + stream = req.get("stream", True) + + if not question: + await send_ws_error("Missing required parameter: question", code=400) + continue + + try: + if stream: + async for response_chunk in agent_completion( + tenant_id=tenant_id, + agent_id=agent_id, + question=question, + session_id=session_id, + stream=True, + **{k: v for k, v in req.items() if k not in ["question", "session_id", "stream"]} + ): + if isinstance(response_chunk, str) and response_chunk.startswith("data:"): + json_str = response_chunk[5:].strip() + try: + response_data = json.loads(json_str) + if response_data.get("event") in ["message", "message_end"]: + await websocket.send(json.dumps({ + "code": 0, + "message": "", + "data": response_data + }, ensure_ascii=False)) + except json.JSONDecodeError: + continue + + await send_ws_message(True) + logging.info(f"Agent completion streamed successfully for agent_id: {agent_id}") + else: + full_content = "" + reference = {} + final_ans = None + + async for response_chunk in agent_completion( + tenant_id=tenant_id, + agent_id=agent_id, + question=question, + session_id=session_id, + stream=False, + **{k: v for k, v in req.items() if k not in ["question", "session_id", "stream"]} + ): + if isinstance(response_chunk, str) and response_chunk.startswith("data:"): + try: + ans = json.loads(response_chunk[5:]) + if ans["event"] == "message": + full_content += ans["data"]["content"] + if ans.get("data", {}).get("reference", None): + reference.update(ans["data"]["reference"]) + final_ans = ans + except Exception as e: + await send_ws_error(str(e)) + continue + + if final_ans: + final_ans["data"]["content"] = full_content + final_ans["data"]["reference"] = reference + await send_ws_message(final_ans) + else: + await send_ws_error("No response generated", code=500) + + except Exception as e: + logging.exception(f"Error during agent completion: {str(e)}") + await send_ws_error(str(e)) + + except Exception as e: + logging.exception(f"WebSocket error: {str(e)}") + try: + await send_ws_error(str(e)) + except Exception: + pass + await websocket.close(1011) + + finally: + logging.info(f"WebSocket agent connection closed for agent_id: {agent_id}") + diff --git a/api/apps/websocket_app.py b/api/apps/websocket_app.py deleted file mode 100644 index 72eaeb52e..000000000 --- a/api/apps/websocket_app.py +++ /dev/null @@ -1,709 +0,0 @@ -# -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -WebSocket API for RAGFlow Streaming Responses - -This module provides WebSocket endpoints for real-time streaming of chat completions. -WebSocket support is essential for platforms like WeChat Mini Programs that require -persistent bidirectional connections for real-time communication. - -Key Features: -- Real-time bidirectional communication via WebSocket -- Support for multiple authentication methods (API Token, User Session) -- Streaming chat completions with incremental responses -- Error handling and connection management -- Compatible with WeChat Mini Programs and other WebSocket clients - -WebSocket Message Format: - Client -> Server (Request): - { - "type": "chat", # Message type (currently supports "chat") - "chat_id": "xxx", # Dialog/Chat ID - "session_id": "xxx", # Optional: Conversation session ID - "question": "Hello", # User's question/message - "stream": true, # Optional: Enable streaming (default: true) - "kb_ids": [] # Optional: Knowledge base IDs to query - } - - Server -> Client (Response): - { - "code": 0, # Status code (0=success, 500=error) - "message": "", # Error message (if any) - "data": { # Response data - "answer": "...", # Incremental answer text (for streaming) - "reference": {...}, # Source references - "id": "xxx", # Message ID - "session_id": "xxx" # Session ID - } - } - - Server -> Client (Completion): - { - "code": 0, - "message": "", - "data": true # Indicates completion of streaming - } - - Server -> Client (Error): - { - "code": 500, - "message": "Error description", - "data": { - "answer": "**ERROR**: Error details", - "reference": [] - } - } - -Connection Lifecycle: -1. Client initiates WebSocket connection with authentication -2. Server validates authentication (API token or user session) -3. Client sends chat message requests -4. Server streams response chunks back to client -5. Server sends completion marker when done -6. Connection remains open for subsequent messages -7. Either party can close the connection -""" - -import logging -import json -from quart import websocket -from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer - -from api.db.db_models import APIToken -from api.db.services.user_service import UserService -from api.db.services.dialog_service import DialogService -from api.db.services.conversation_service import completion -from common.constants import StatusEnum -from common import settings - - -# ----------------------------------------------------------------------------- -# Authentication Helper Functions -# ----------------------------------------------------------------------------- - -async def authenticate_websocket(): - """ - Authenticate WebSocket connection using multiple methods. - - This function attempts to authenticate the WebSocket connection using: - 1. API Token authentication (Bearer token in Authorization header) - 2. User session authentication (Session-based JWT token) - 3. Query parameter authentication (token passed as URL parameter) - - Authentication Methods: - - API Token: Used by external applications, bots, and integrations - - User Session: Used by web interface and logged-in users - - Query Parameter: Fallback for clients that can't send headers - - Returns: - tuple: (authenticated: bool, tenant_id: str|None, error_message: str|None) - - Examples: - # API Token authentication - ws://host/ws/chat?Authorization=Bearer ragflow-xxxxx - - # Query parameter authentication - ws://host/ws/chat?token=ragflow-xxxxx - """ - tenant_id = None - error_message = None - - # Method 1: Try API Token authentication from Authorization header - # This is the preferred method for SDK and API integrations - authorization = websocket.headers.get("Authorization", "") - - if authorization: - try: - # Parse Bearer token format: "Bearer " - authorization_parts = authorization.split() - - if len(authorization_parts) >= 2: - token = authorization_parts[1] - - # Query database for matching API token - objs = APIToken.query(token=token) - - if objs: - # Valid API token found, extract tenant ID - tenant_id = objs[0].tenant_id - logging.info(f"WebSocket authenticated via API token for tenant: {tenant_id}") - return True, tenant_id, None - else: - error_message = "Invalid API token" - logging.warning(f"WebSocket authentication failed: {error_message}") - else: - error_message = "Invalid Authorization header format. Expected: 'Bearer '" - logging.warning(f"WebSocket authentication failed: {error_message}") - - except Exception as e: - error_message = f"Error processing API token: {str(e)}" - logging.error(f"WebSocket authentication error: {error_message}") - - # Method 2: Try User Session authentication (JWT token) - # This is used by the web interface for logged-in users - try: - jwt = Serializer(secret_key=settings.SECRET_KEY) - - # Try to get authorization from header or query parameter - auth_token = websocket.headers.get("Authorization") or \ - websocket.args.get("authorization") or \ - websocket.args.get("token") - - if auth_token: - try: - # Decode JWT token to get access token - access_token = str(jwt.loads(auth_token)) - - # Validate access token format - if access_token and len(access_token.strip()) >= 32: - # Query user by access token - user = UserService.query( - access_token=access_token, - status=StatusEnum.VALID.value - ) - - if user and user[0]: - # Valid user session found - tenant_id = user[0].id - logging.info(f"WebSocket authenticated via user session for user: {user[0].email}") - return True, tenant_id, None - - except Exception as e: - # JWT decoding or validation failed - logging.debug(f"User session authentication failed: {str(e)}") - - except Exception as e: - logging.error(f"Error in user session authentication: {str(e)}") - - # Method 3: Try query parameter authentication - # Fallback for clients that cannot set custom headers - token_param = websocket.args.get("token") - if token_param: - try: - objs = APIToken.query(token=token_param) - if objs: - tenant_id = objs[0].tenant_id - logging.info(f"WebSocket authenticated via query parameter for tenant: {tenant_id}") - return True, tenant_id, None - except Exception as e: - logging.error(f"Query parameter authentication error: {str(e)}") - - # No valid authentication method succeeded - if not error_message: - error_message = "Authentication required. Please provide valid API token or user session." - - return False, None, error_message - - -async def send_error(error_message, code=500): - """ - Send error message to WebSocket client in standardized format. - - Args: - error_message (str): Human-readable error description - code (int): Error code (default: 500 for server errors) - - Error Response Format: - { - "code": 500, - "message": "Error description", - "data": { - "answer": "**ERROR**: Error details", - "reference": [] - } - } - """ - error_response = { - "code": code, - "message": error_message, - "data": { - "answer": f"**ERROR**: {error_message}", - "reference": [] - } - } - - await websocket.send(json.dumps(error_response, ensure_ascii=False)) - logging.error(f"WebSocket error sent: {error_message}") - - -async def send_message(data, code=0, message=""): - """ - Send message to WebSocket client in standardized format. - - Args: - data: Response data (can be dict, bool, or any JSON-serializable object) - code (int): Status code (0 for success) - message (str): Optional status message - - Success Response Format: - { - "code": 0, - "message": "", - "data": {...} - } - """ - response = { - "code": code, - "message": message, - "data": data - } - - await websocket.send(json.dumps(response, ensure_ascii=False)) - - -# ----------------------------------------------------------------------------- -# WebSocket Endpoint: Chat Completions -# ----------------------------------------------------------------------------- - -@manager.route("/ws/chat") # noqa: F821 -async def websocket_chat(): - """ - WebSocket endpoint for real-time chat completions with streaming responses. - - This endpoint provides a persistent WebSocket connection for interactive chat - sessions. It supports streaming responses, allowing clients to receive - incremental updates as the AI generates the response. - - Connection URL: - ws://host/v1/ws/chat - - Authentication: - - Authorization header: "Bearer " - - Query parameter: "?token=" - - User session JWT - - Message Flow: - 1. Client connects and authenticates - 2. Client sends chat request message - 3. Server streams response chunks - 4. Server sends completion marker - 5. Connection stays open for more messages - - Supported Features: - - Multi-turn conversations with session tracking - - Knowledge base integration for RAG - - Reference/citation tracking - - Error recovery and graceful degradation - - Example Client Code (JavaScript): - ```javascript - const ws = new WebSocket('ws://host/v1/ws/chat?token=YOUR_TOKEN'); - - ws.onopen = () => { - ws.send(JSON.stringify({ - type: 'chat', - chat_id: 'your-chat-id', - question: 'Hello, how are you?', - stream: true - })); - }; - - ws.onmessage = (event) => { - const response = JSON.parse(event.data); - if (response.data === true) { - console.log('Stream completed'); - } else { - console.log('Received:', response.data.answer); - } - }; - ``` - - Example Client Code (Python): - ```python - import websocket - import json - - def on_message(ws, message): - data = json.loads(message) - if data['data'] is True: - print('Stream completed') - else: - print('Received:', data['data']['answer']) - - ws = websocket.WebSocketApp( - 'ws://host/v1/ws/chat?token=YOUR_TOKEN', - on_message=on_message - ) - - ws.on_open = lambda ws: ws.send(json.dumps({ - 'type': 'chat', - 'chat_id': 'your-chat-id', - 'question': 'Hello!', - 'stream': True - })) - - ws.run_forever() - ``` - """ - # Step 1: Authenticate the WebSocket connection - # This ensures only authorized clients can access the chat service - authenticated, tenant_id, error_msg = await authenticate_websocket() - - if not authenticated: - # Authentication failed - send error and close connection - await send_error(error_msg, code=401) - await websocket.close(1008, error_msg) # 1008 = Policy Violation - return - - # Authentication successful - log connection - logging.info(f"WebSocket chat connection established for tenant: {tenant_id}") - - # Step 2: Connection loop - handle multiple messages over same connection - # WebSocket connections are persistent, allowing multiple request/response cycles - try: - # Keep connection open and process incoming messages - while True: - # Wait for message from client - # This is a blocking call that waits until client sends data - message = await websocket.receive() - - # Parse JSON message from client - try: - request_data = json.loads(message) - except json.JSONDecodeError as e: - # Invalid JSON format - send error but keep connection open - await send_error(f"Invalid JSON format: {str(e)}", code=400) - continue - - # Extract message type (currently only 'chat' is supported) - message_type = request_data.get("type", "chat") - - # Step 3: Route message to appropriate handler based on type - if message_type == "chat": - # Handle chat completion request - await handle_chat_request(tenant_id, request_data) - else: - # Unknown message type - send error but keep connection open - await send_error(f"Unknown message type: {message_type}", code=400) - - except Exception as e: - # Unexpected error occurred - log and notify client - error_message = f"WebSocket error: {str(e)}" - logging.exception(error_message) - - try: - await send_error(error_message) - except Exception: - # Failed to send error (connection may be closed) - pass - - # Close connection with error code - await websocket.close(1011, "Internal server error") # 1011 = Internal Error - - finally: - # Connection closed - cleanup and log - logging.info(f"WebSocket chat connection closed for tenant: {tenant_id}") - - -async def handle_chat_request(tenant_id, request_data): - """ - Handle chat completion request received via WebSocket. - - This function processes a chat request, validates parameters, retrieves - the dialog configuration, and streams the AI response back to the client. - - Args: - tenant_id (str): Authenticated tenant/user ID - request_data (dict): Parsed JSON request from client - - Required Request Fields: - - chat_id (str): Dialog/Chat ID to use for the conversation - - question (str): User's question or message - - Optional Request Fields: - - session_id (str): Existing conversation session ID (creates new if not provided) - - stream (bool): Enable streaming responses (default: True) - - kb_ids (list): Knowledge base IDs to include in retrieval - - doc_ids (str): Comma-separated document IDs to prioritize - - files (list): File IDs attached to this message - - Processing Steps: - 1. Validate required parameters - 2. Verify dialog ownership and permissions - 3. Create or retrieve conversation session - 4. Stream AI-generated response chunks - 5. Send completion marker - - Error Handling: - - Missing parameters: Returns 400 error - - Invalid dialog: Returns 404 error - - Permission denied: Returns 403 error - - Processing error: Returns 500 error - """ - try: - # Step 1: Extract and validate required parameters - chat_id = request_data.get("chat_id") - question = request_data.get("question", "") - session_id = request_data.get("session_id") - stream = request_data.get("stream", True) - - # Validate chat_id is provided - if not chat_id: - await send_error("Missing required parameter: chat_id", code=400) - return - - # Validate question is provided (empty questions are allowed for session initialization) - if question is None: - await send_error("Missing required parameter: question", code=400) - return - - # Step 2: Verify dialog exists and user has access - # Check if the dialog belongs to this tenant and is active - dialog_query = DialogService.query( - tenant_id=tenant_id, - id=chat_id, - status=StatusEnum.VALID.value - ) - - if not dialog_query: - # Dialog not found or user doesn't have permission - await send_error(f"Dialog not found or access denied: {chat_id}", code=404) - return - - # Step 3: Extract optional parameters for enhanced functionality - # These parameters customize the retrieval and generation process - additional_params = {} - - # Knowledge base filtering - limit search to specific KBs - if "kb_ids" in request_data: - additional_params["kb_ids"] = request_data["kb_ids"] - - # Document filtering - prioritize specific documents - if "doc_ids" in request_data: - additional_params["doc_ids"] = request_data["doc_ids"] - - # File attachments - include files uploaded with this message - if "files" in request_data: - additional_params["files"] = request_data["files"] - - # Pass through any other custom parameters - # This allows for future extensibility without code changes - for key, value in request_data.items(): - if key not in ["type", "chat_id", "question", "session_id", "stream"]: - if key not in additional_params: - additional_params[key] = value - - # Step 4: Process chat completion with streaming - if stream: - # Streaming mode: Send incremental response chunks - # This provides a better user experience with real-time feedback - - try: - # Call the completion service which yields response chunks - # The completion function handles session management, RAG retrieval, - # LLM generation, and response formatting - # Note: completion() is a synchronous generator, not async - for response_chunk in completion( - tenant_id=tenant_id, - chat_id=chat_id, - question=question, - session_id=session_id, - stream=True, - **additional_params - ): - # Parse the SSE-formatted response - # completion() returns "data:{json}\n\n" format for compatibility - if response_chunk.startswith("data:"): - # Extract JSON from SSE format - json_str = response_chunk[5:].strip() - - # Parse and forward to WebSocket client - try: - response_data = json.loads(json_str) - - # Send the chunk to WebSocket client - await websocket.send(json.dumps(response_data, ensure_ascii=False)) - - except json.JSONDecodeError: - # Malformed response chunk - log but continue - logging.warning(f"Failed to parse response chunk: {json_str}") - continue - - # Stream completed successfully - logging.info(f"Chat completion streamed successfully for chat_id: {chat_id}") - - except Exception as e: - # Error during streaming - send error message - error_message = f"Error during chat completion: {str(e)}" - logging.exception(error_message) - await send_error(error_message) - - else: - # Non-streaming mode: Send complete response at once - # This is simpler but provides no incremental feedback - - try: - # Get the complete response (completion yields once for non-streaming) - response = None - for resp in completion( - tenant_id=tenant_id, - chat_id=chat_id, - question=question, - session_id=session_id, - stream=False, - **additional_params - ): - response = resp - break # Only one response in non-streaming mode - - # Send complete response - if response: - await send_message(response) - else: - await send_error("No response generated", code=500) - - logging.info(f"Chat completion completed (non-streaming) for chat_id: {chat_id}") - - except Exception as e: - # Error during generation - send error message - error_message = f"Error during chat completion: {str(e)}" - logging.exception(error_message) - await send_error(error_message) - - except Exception as e: - # Unexpected error in request handling - error_message = f"Error handling chat request: {str(e)}" - logging.exception(error_message) - await send_error(error_message) - - -# ----------------------------------------------------------------------------- -# WebSocket Endpoint: Agent Completions (Future Enhancement) -# ----------------------------------------------------------------------------- - -@manager.route("/ws/agent") # noqa: F821 -async def websocket_agent(): - """ - WebSocket endpoint for agent-based completions with streaming. - - This endpoint is similar to websocket_chat but designed for agent-based - interactions. Agents can have custom tools, workflows, and behaviors - beyond standard RAG chat. - - Note: This is a placeholder for future implementation. The authentication - and connection handling logic is the same as websocket_chat. - - Future Enhancements: - - Tool calling and function execution - - Multi-step agent reasoning - - Agent state management - - Custom agent workflows - """ - # Authenticate connection - authenticated, tenant_id, error_msg = await authenticate_websocket() - - if not authenticated: - await send_error(error_msg, code=401) - await websocket.close(1008, error_msg) - return - - logging.info(f"WebSocket agent connection established for tenant: {tenant_id}") - - # Connection loop - try: - while True: - message = await websocket.receive() - - try: - request_data = json.loads(message) - except json.JSONDecodeError as e: - await send_error(f"Invalid JSON format: {str(e)}", code=400) - continue - - # Handle agent completion request - await handle_agent_request(tenant_id, request_data) - - except Exception as e: - error_message = f"WebSocket error: {str(e)}" - logging.exception(error_message) - - try: - await send_error(error_message) - except Exception: - pass - - await websocket.close(1011, "Internal server error") - - finally: - logging.info(f"WebSocket agent connection closed for tenant: {tenant_id}") - - -async def handle_agent_request(tenant_id, request_data): - """ - Handle agent completion request received via WebSocket. - - This is a placeholder for future agent functionality. - - Args: - tenant_id (str): Authenticated tenant/user ID - request_data (dict): Parsed JSON request from client - """ - # TODO: Implement agent-specific logic - # For now, return a not-implemented error - await send_error("Agent completions not yet implemented", code=501) - - logging.info("Agent request received but not yet implemented") - - -# ----------------------------------------------------------------------------- -# WebSocket Health Check Endpoint -# ----------------------------------------------------------------------------- - -@manager.route("/ws/health") # noqa: F821 -async def websocket_health(): - """ - WebSocket health check endpoint. - - This endpoint allows clients to verify WebSocket connectivity - without authentication. Useful for monitoring and diagnostics. - - The server will echo back any messages received, allowing clients - to test round-trip latency and connection stability. - - Example Usage: - ```javascript - const ws = new WebSocket('ws://host/v1/ws/health'); - ws.onopen = () => ws.send('ping'); - ws.onmessage = (e) => console.log('Received:', e.data); - ``` - """ - logging.info("WebSocket health check connection established") - - try: - # Send initial health status - await websocket.send(json.dumps({ - "status": "healthy", - "message": "WebSocket connection established", - "version": "1.0" - })) - - # Echo messages back to client - while True: - message = await websocket.receive() - - # Echo the message back - await websocket.send(json.dumps({ - "echo": message, - "timestamp": str(logging.time.time()) - })) - - except Exception as e: - logging.info(f"WebSocket health check closed: {str(e)}") - - finally: - logging.info("WebSocket health check connection closed") - diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 8f17e1de0..c8eb578aa 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -1,5 +1,5 @@ # -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -283,6 +283,93 @@ def token_required(func): return decorated_function +def ws_token_required(func): + """ + WebSocket authentication decorator for SDK endpoints. + Follows the same pattern as token_required but for WebSocket connections. + """ + from quart import websocket + from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer + from api.db.services.user_service import UserService + from common.constants import StatusEnum + + async def get_tenant_id_from_websocket(**kwargs): + """Extract tenant_id from WebSocket authentication.""" + # Method 1: Try API Token authentication from Authorization header + authorization = websocket.headers.get("Authorization", "") + + if authorization: + try: + authorization_parts = authorization.split() + if len(authorization_parts) >= 2: + token = authorization_parts[1] + objs = APIToken.query(token=token) + if objs: + kwargs["tenant_id"] = objs[0].tenant_id + logging.info(f"WebSocket authenticated via API token") + return True, kwargs + except Exception as e: + logging.error(f"WebSocket API token auth error: {str(e)}") + + # Method 2: Try User Session authentication (JWT) + try: + jwt = Serializer(secret_key=settings.SECRET_KEY) + auth_token = websocket.headers.get("Authorization") or \ + websocket.args.get("authorization") or \ + websocket.args.get("token") + + if auth_token: + try: + if auth_token.startswith("Bearer "): + auth_token = auth_token[7:] + access_token = str(jwt.loads(auth_token)) + if access_token and len(access_token.strip()) >= 32: + user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value) + if user and user[0]: + kwargs["tenant_id"] = user[0].id + logging.info(f"WebSocket authenticated via user session") + return True, kwargs + except Exception: + pass + except Exception: + pass + + # Method 3: Try query parameter authentication + token_param = websocket.args.get("token") + if token_param: + try: + objs = APIToken.query(token=token_param) + if objs: + kwargs["tenant_id"] = objs[0].tenant_id + logging.info(f"WebSocket authenticated via query parameter") + return True, kwargs + except Exception: + pass + + return False, "Authentication required. Please provide valid API token or user session." + + @wraps(func) + async def adecorated_function(*args, **kwargs): + """Async wrapper for WebSocket endpoint.""" + success, result = await get_tenant_id_from_websocket(**kwargs) + + if not success: + # Authentication failed - send error and close connection + error_response = { + "code": RetCode.AUTHENTICATION_ERROR, + "message": result, + "data": {"answer": f"**ERROR**: {result}", "reference": []} + } + await websocket.send(json.dumps(error_response, ensure_ascii=False)) + await websocket.close(1008, result) # 1008 = Policy Violation + return + + # Authentication successful - call the actual handler + return await func(*args, **result) + + return adecorated_function + + def get_result(code=RetCode.SUCCESS, message="", data=None, total=None): """ Standard API response format: diff --git a/pyproject.toml b/pyproject.toml index 9da1236c0..e63ff71b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,11 +163,13 @@ test = [ "openpyxl>=3.1.5", "pillow>=10.4.0", "pytest>=8.3.5", + "pytest-asyncio>=0.24.0", "python-docx>=1.1.2", "python-pptx>=1.0.2", "reportlab>=4.4.1", "requests>=2.32.2", "requests-toolbelt>=1.0.0", + "websockets>=14.0", ] [[tool.uv.index]]