diff --git a/src/api/v1/chat.py b/src/api/v1/chat.py index 4a8d3e3d..ecadba16 100644 --- a/src/api/v1/chat.py +++ b/src/api/v1/chat.py @@ -2,7 +2,7 @@ Public API v1 Chat endpoint. Provides chat functionality with streaming support and conversation history. -Uses API key authentication. +Uses API key authentication. Routes through Langflow (chat_service.langflow_chat). """ import json from starlette.requests import Request @@ -15,25 +15,18 @@ logger = get_logger(__name__) async def _transform_stream_to_sse(raw_stream, chat_id_container: dict): """ - Transform the raw internal streaming format to clean SSE events. + Transform the raw Langflow streaming format to clean SSE events for v1 API. Yields SSE events in the format: - event: content data: {"type": "content", "delta": "..."} - - event: sources data: {"type": "sources", "sources": [...]} - - event: done data: {"type": "done", "chat_id": "..."} """ full_text = "" - sources = [] chat_id = None async for chunk in raw_stream: try: - # Decode the chunk if isinstance(chunk, bytes): chunk_str = chunk.decode("utf-8").strip() else: @@ -42,131 +35,76 @@ async def _transform_stream_to_sse(raw_stream, chat_id_container: dict): if not chunk_str: continue - # Parse the JSON chunk chunk_data = json.loads(chunk_str) - # Extract text delta - delta_text = None + # Extract text from various possible formats + delta_text = "" + + # Format 1: delta.content (OpenAI-style) if "delta" in chunk_data: delta = chunk_data["delta"] if isinstance(delta, dict): - delta_text = delta.get("content") or delta.get("text") or "" + delta_text = delta.get("content", "") or delta.get("text", "") elif isinstance(delta, str): delta_text = delta - if "output_text" in chunk_data and chunk_data["output_text"]: + # Format 2: output_text (Langflow-style) + if not delta_text and chunk_data.get("output_text"): delta_text = chunk_data["output_text"] - # Yield content event if we have text + # Format 3: text field directly + if not delta_text and chunk_data.get("text"): + delta_text = chunk_data["text"] + + # Format 4: content field directly + if not delta_text and chunk_data.get("content"): + delta_text = chunk_data["content"] + if delta_text: full_text += delta_text - event = {"type": "content", "delta": delta_text} - yield f"event: content\ndata: {json.dumps(event)}\n\n" + yield f"data: {json.dumps({'type': 'content', 'delta': delta_text})}\n\n" - # Extract chat_id/response_id - if "id" in chunk_data and chunk_data["id"]: - chat_id = chunk_data["id"] - elif "response_id" in chunk_data and chunk_data["response_id"]: - chat_id = chunk_data["response_id"] - - # Extract sources from tool call results - if "item" in chunk_data and isinstance(chunk_data["item"], dict): - item = chunk_data["item"] - if item.get("type") in ("retrieval_call", "tool_call", "function_call"): - results = item.get("results", []) - if results: - for result in results: - if isinstance(result, dict): - source = { - "filename": result.get("filename", result.get("title", "Unknown")), - "text": result.get("text", result.get("content", "")), - "score": result.get("score", 0), - "page": result.get("page"), - } - sources.append(source) + # Extract chat_id/response_id from various fields + if not chat_id: + chat_id = chunk_data.get("id") or chunk_data.get("response_id") except json.JSONDecodeError: - # Not JSON, might be raw text - if chunk_str and not chunk_str.startswith("{"): - event = {"type": "content", "delta": chunk_str} - yield f"event: content\ndata: {json.dumps(event)}\n\n" + # Raw text without JSON wrapper + if chunk_str: + yield f"data: {json.dumps({'type': 'content', 'delta': chunk_str})}\n\n" full_text += chunk_str except Exception as e: - logger.warning("Error processing stream chunk", error=str(e)) - continue + logger.warning("Error processing stream chunk", error=str(e), chunk=chunk_str[:100] if chunk_str else "") - # Yield sources event if we have any - if sources: - event = {"type": "sources", "sources": sources} - yield f"event: sources\ndata: {json.dumps(event)}\n\n" - - # Yield done event with chat_id - event = {"type": "done", "chat_id": chat_id} - yield f"event: done\ndata: {json.dumps(event)}\n\n" - - # Store chat_id for caller + yield f"data: {json.dumps({'type': 'done', 'chat_id': chat_id})}\n\n" chat_id_container["chat_id"] = chat_id async def chat_create_endpoint(request: Request, chat_service, session_manager): """ - Send a chat message. + Send a chat message via Langflow. POST /v1/chat - - Request body: - { - "message": "What is RAG?", - "stream": false, // optional, default false - "chat_id": "...", // optional, to continue conversation - "filters": {...}, // optional - "limit": 10, // optional - "score_threshold": 0.5 // optional - } - - Non-streaming response: - { - "response": "RAG stands for...", - "chat_id": "chat_xyz789", - "sources": [...] - } - - Streaming response (SSE): - event: content - data: {"type": "content", "delta": "RAG stands for"} - - event: sources - data: {"type": "sources", "sources": [...]} - - event: done - data: {"type": "done", "chat_id": "chat_xyz789"} """ try: data = await request.json() except Exception: - return JSONResponse( - {"error": "Invalid JSON in request body"}, - status_code=400, - ) + return JSONResponse({"error": "Invalid JSON in request body"}, status_code=400) message = data.get("message", "").strip() if not message: - return JSONResponse( - {"error": "Message is required"}, - status_code=400, - ) + return JSONResponse({"error": "Message is required"}, status_code=400) stream = data.get("stream", False) - chat_id = data.get("chat_id") # For conversation continuation + chat_id = data.get("chat_id") filters = data.get("filters") limit = data.get("limit", 10) score_threshold = data.get("score_threshold", 0) + filter_id = data.get("filter_id") user = request.state.user user_id = user.user_id - - # Note: API key auth doesn't have JWT, so we pass None - jwt_token = None + jwt_token = session_manager.get_effective_jwt_token(user_id, None) # Set context variables for search tool if filters: @@ -176,45 +114,35 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager): set_auth_context(user_id, jwt_token) if stream: - # Streaming response - raw_stream = await chat_service.chat( + raw_stream = await chat_service.langflow_chat( prompt=message, user_id=user_id, jwt_token=jwt_token, previous_response_id=chat_id, stream=True, + filter_id=filter_id, ) - chat_id_container = {} - return StreamingResponse( _transform_stream_to_sse(raw_stream, chat_id_container), media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, + headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) else: - # Non-streaming response - result = await chat_service.chat( + result = await chat_service.langflow_chat( prompt=message, user_id=user_id, jwt_token=jwt_token, previous_response_id=chat_id, stream=False, + filter_id=filter_id, ) - - # Transform response to public API format - # Internal format: {"response": "...", "response_id": "..."} - response_data = { + # Transform response_id to chat_id for v1 API format + return JSONResponse({ "response": result.get("response", ""), "chat_id": result.get("response_id"), "sources": result.get("sources", []), - } - - return JSONResponse(response_data) + }) async def chat_list_endpoint(request: Request, chat_service, session_manager): @@ -240,8 +168,8 @@ async def chat_list_endpoint(request: Request, chat_service, session_manager): user_id = user.user_id try: - # Get chat history - history = await chat_service.get_chat_history(user_id) + # Get Langflow chat history (since v1 routes through Langflow) + history = await chat_service.get_langflow_history(user_id) # Transform to public API format conversations = [] @@ -293,8 +221,8 @@ async def chat_get_endpoint(request: Request, chat_service, session_manager): ) try: - # Get chat history and find the specific conversation - history = await chat_service.get_chat_history(user_id) + # Get Langflow chat history and find the specific conversation + history = await chat_service.get_langflow_history(user_id) conversation = None for conv in history.get("conversations", []):