diff --git a/src/api/v1/chat.py b/src/api/v1/chat.py index 43f3ba25..418f08ee 100644 --- a/src/api/v1/chat.py +++ b/src/api/v1/chat.py @@ -2,69 +2,142 @@ Public API v1 Chat endpoint. Provides chat functionality with streaming support and conversation history. -Uses API key authentication. Routes through Langflow endpoint. +Uses API key authentication. Routes through Langflow (chat_service.langflow_chat). """ import json from starlette.requests import Request -from starlette.responses import JSONResponse +from starlette.responses import JSONResponse, StreamingResponse from utils.logging_config import get_logger -from api.chat import langflow_endpoint +from auth_context import set_search_filters, set_search_limit, set_score_threshold, set_auth_context logger = get_logger(__name__) -def _transform_v1_request_to_internal(data: dict) -> dict: - """Transform v1 API request format to internal Langflow format.""" - return { - "prompt": data.get("message", ""), # v1 uses "message", internal uses "prompt" - "previous_response_id": data.get("chat_id"), # v1 uses "chat_id" - "stream": data.get("stream", False), - "filters": data.get("filters"), - "limit": data.get("limit", 10), - "scoreThreshold": data.get("score_threshold", 0), # v1 uses snake_case - "filter_id": data.get("filter_id"), - } +async def _transform_stream_to_sse(raw_stream, chat_id_container: dict): + """ + Transform the raw Langflow streaming format to clean SSE events for v1 API. + + Yields SSE events in the format: + data: {"type": "content", "delta": "..."} + data: {"type": "sources", "sources": [...]} + data: {"type": "done", "chat_id": "..."} + """ + full_text = "" + sources = [] + chat_id = None + + async for chunk in raw_stream: + try: + if isinstance(chunk, bytes): + chunk_str = chunk.decode("utf-8").strip() + else: + chunk_str = str(chunk).strip() + + if not chunk_str: + continue + + chunk_data = json.loads(chunk_str) + + # Extract text delta + delta_text = None + if "delta" in chunk_data: + delta = chunk_data["delta"] + if isinstance(delta, dict): + delta_text = delta.get("content") or delta.get("text") or "" + elif isinstance(delta, str): + delta_text = delta + + if "output_text" in chunk_data and chunk_data["output_text"]: + delta_text = chunk_data["output_text"] + + if delta_text: + full_text += delta_text + yield f"data: {json.dumps({'type': 'content', 'delta': delta_text})}\n\n" + + # Extract chat_id/response_id + if "id" in chunk_data and chunk_data["id"]: + chat_id = chunk_data["id"] + elif "response_id" in chunk_data and chunk_data["response_id"]: + chat_id = chunk_data["response_id"] + + except json.JSONDecodeError: + if chunk_str and not chunk_str.startswith("{"): + yield f"data: {json.dumps({'type': 'content', 'delta': chunk_str})}\n\n" + full_text += chunk_str + except Exception as e: + logger.warning("Error processing stream chunk", error=str(e)) + continue + + if sources: + yield f"data: {json.dumps({'type': 'sources', 'sources': sources})}\n\n" + + yield f"data: {json.dumps({'type': 'done', 'chat_id': chat_id})}\n\n" + chat_id_container["chat_id"] = chat_id async def chat_create_endpoint(request: Request, chat_service, session_manager): """ - Send a chat message. Routes to internal Langflow endpoint. + Send a chat message via Langflow. - POST /v1/chat - see internal /langflow endpoint for full documentation. - Transforms v1 format (message, chat_id, score_threshold) to internal format. + POST /v1/chat """ try: data = await request.json() except Exception: - return JSONResponse( - {"error": "Invalid JSON in request body"}, - status_code=400, - ) + return JSONResponse({"error": "Invalid JSON in request body"}, status_code=400) message = data.get("message", "").strip() if not message: - return JSONResponse( - {"error": "Message is required"}, - status_code=400, + return JSONResponse({"error": "Message is required"}, status_code=400) + + stream = data.get("stream", False) + chat_id = data.get("chat_id") + filters = data.get("filters") + limit = data.get("limit", 10) + score_threshold = data.get("score_threshold", 0) + filter_id = data.get("filter_id") + + user = request.state.user + user_id = user.user_id + jwt_token = session_manager.get_effective_jwt_token(user_id, None) + + # Set context variables for search tool + if filters: + set_search_filters(filters) + set_search_limit(limit) + set_score_threshold(score_threshold) + set_auth_context(user_id, jwt_token) + + if stream: + raw_stream = await chat_service.langflow_chat( + prompt=message, + user_id=user_id, + jwt_token=jwt_token, + previous_response_id=chat_id, + stream=True, + filter_id=filter_id, ) - - # Transform v1 request to internal format - internal_data = _transform_v1_request_to_internal(data) - - # Create a new request with transformed body for the internal endpoint - body = json.dumps(internal_data).encode() - - async def receive(): - return {"type": "http.request", "body": body} - - internal_request = Request(request.scope, receive) - - # Copy state attributes individually (state property has no setter) - internal_request.state.user = request.state.user - internal_request.state.jwt_token = getattr(request.state, "jwt_token", None) - - # Call internal Langflow endpoint - return await langflow_endpoint(internal_request, chat_service, session_manager) + chat_id_container = {} + return StreamingResponse( + _transform_stream_to_sse(raw_stream, chat_id_container), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, + ) + else: + result = await chat_service.langflow_chat( + prompt=message, + user_id=user_id, + jwt_token=jwt_token, + previous_response_id=chat_id, + stream=False, + filter_id=filter_id, + ) + # Transform response_id to chat_id for v1 API format + return JSONResponse({ + "response": result.get("response", ""), + "chat_id": result.get("response_id"), + "sources": result.get("sources", []), + }) async def chat_list_endpoint(request: Request, chat_service, session_manager):