From 850c504c55e2823c363f04fed7d5ee0dd402ea1e Mon Sep 17 00:00:00 2001 From: phact Date: Mon, 22 Dec 2025 22:00:42 -0500 Subject: [PATCH] sdk chat endpoint fix --- src/api/v1/chat.py | 215 +++++++-------------------------------------- 1 file changed, 31 insertions(+), 184 deletions(-) diff --git a/src/api/v1/chat.py b/src/api/v1/chat.py index 4a8d3e3d..39d9b179 100644 --- a/src/api/v1/chat.py +++ b/src/api/v1/chat.py @@ -2,144 +2,36 @@ Public API v1 Chat endpoint. Provides chat functionality with streaming support and conversation history. -Uses API key authentication. +Uses API key authentication. Routes through Langflow endpoint. """ import json from starlette.requests import Request -from starlette.responses import JSONResponse, StreamingResponse +from starlette.responses import JSONResponse from utils.logging_config import get_logger -from auth_context import set_search_filters, set_search_limit, set_score_threshold, set_auth_context +from api.chat import langflow_endpoint logger = get_logger(__name__) -async def _transform_stream_to_sse(raw_stream, chat_id_container: dict): - """ - Transform the raw internal streaming format to clean SSE events. - - Yields SSE events in the format: - event: content - data: {"type": "content", "delta": "..."} - - event: sources - data: {"type": "sources", "sources": [...]} - - event: done - data: {"type": "done", "chat_id": "..."} - """ - full_text = "" - sources = [] - chat_id = None - - async for chunk in raw_stream: - try: - # Decode the chunk - if isinstance(chunk, bytes): - chunk_str = chunk.decode("utf-8").strip() - else: - chunk_str = str(chunk).strip() - - if not chunk_str: - continue - - # Parse the JSON chunk - chunk_data = json.loads(chunk_str) - - # Extract text delta - delta_text = None - if "delta" in chunk_data: - delta = chunk_data["delta"] - if isinstance(delta, dict): - delta_text = delta.get("content") or delta.get("text") or "" - elif isinstance(delta, str): - delta_text = delta - - if "output_text" in chunk_data and chunk_data["output_text"]: - delta_text = chunk_data["output_text"] - - # Yield content event if we have text - if delta_text: - full_text += delta_text - event = {"type": "content", "delta": delta_text} - yield f"event: content\ndata: {json.dumps(event)}\n\n" - - # Extract chat_id/response_id - if "id" in chunk_data and chunk_data["id"]: - chat_id = chunk_data["id"] - elif "response_id" in chunk_data and chunk_data["response_id"]: - chat_id = chunk_data["response_id"] - - # Extract sources from tool call results - if "item" in chunk_data and isinstance(chunk_data["item"], dict): - item = chunk_data["item"] - if item.get("type") in ("retrieval_call", "tool_call", "function_call"): - results = item.get("results", []) - if results: - for result in results: - if isinstance(result, dict): - source = { - "filename": result.get("filename", result.get("title", "Unknown")), - "text": result.get("text", result.get("content", "")), - "score": result.get("score", 0), - "page": result.get("page"), - } - sources.append(source) - - except json.JSONDecodeError: - # Not JSON, might be raw text - if chunk_str and not chunk_str.startswith("{"): - event = {"type": "content", "delta": chunk_str} - yield f"event: content\ndata: {json.dumps(event)}\n\n" - full_text += chunk_str - except Exception as e: - logger.warning("Error processing stream chunk", error=str(e)) - continue - - # Yield sources event if we have any - if sources: - event = {"type": "sources", "sources": sources} - yield f"event: sources\ndata: {json.dumps(event)}\n\n" - - # Yield done event with chat_id - event = {"type": "done", "chat_id": chat_id} - yield f"event: done\ndata: {json.dumps(event)}\n\n" - - # Store chat_id for caller - chat_id_container["chat_id"] = chat_id +def _transform_v1_request_to_internal(data: dict) -> dict: + """Transform v1 API request format to internal Langflow format.""" + return { + "prompt": data.get("message", ""), # v1 uses "message", internal uses "prompt" + "previous_response_id": data.get("chat_id"), # v1 uses "chat_id" + "stream": data.get("stream", False), + "filters": data.get("filters"), + "limit": data.get("limit", 10), + "scoreThreshold": data.get("score_threshold", 0), # v1 uses snake_case + "filter_id": data.get("filter_id"), + } async def chat_create_endpoint(request: Request, chat_service, session_manager): """ - Send a chat message. + Send a chat message. Routes to internal Langflow endpoint. - POST /v1/chat - - Request body: - { - "message": "What is RAG?", - "stream": false, // optional, default false - "chat_id": "...", // optional, to continue conversation - "filters": {...}, // optional - "limit": 10, // optional - "score_threshold": 0.5 // optional - } - - Non-streaming response: - { - "response": "RAG stands for...", - "chat_id": "chat_xyz789", - "sources": [...] - } - - Streaming response (SSE): - event: content - data: {"type": "content", "delta": "RAG stands for"} - - event: sources - data: {"type": "sources", "sources": [...]} - - event: done - data: {"type": "done", "chat_id": "chat_xyz789"} + POST /v1/chat - see internal /langflow endpoint for full documentation. + Transforms v1 format (message, chat_id, score_threshold) to internal format. """ try: data = await request.json() @@ -156,65 +48,20 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager): status_code=400, ) - stream = data.get("stream", False) - chat_id = data.get("chat_id") # For conversation continuation - filters = data.get("filters") - limit = data.get("limit", 10) - score_threshold = data.get("score_threshold", 0) + # Transform v1 request to internal format + internal_data = _transform_v1_request_to_internal(data) - user = request.state.user - user_id = user.user_id + # Create a new request with transformed body for the internal endpoint + body = json.dumps(internal_data).encode() - # Note: API key auth doesn't have JWT, so we pass None - jwt_token = None + async def receive(): + return {"type": "http.request", "body": body} - # Set context variables for search tool - if filters: - set_search_filters(filters) - set_search_limit(limit) - set_score_threshold(score_threshold) - set_auth_context(user_id, jwt_token) + internal_request = Request(request.scope, receive) + internal_request.state = request.state # Copy state for auth - if stream: - # Streaming response - raw_stream = await chat_service.chat( - prompt=message, - user_id=user_id, - jwt_token=jwt_token, - previous_response_id=chat_id, - stream=True, - ) - - chat_id_container = {} - - return StreamingResponse( - _transform_stream_to_sse(raw_stream, chat_id_container), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - else: - # Non-streaming response - result = await chat_service.chat( - prompt=message, - user_id=user_id, - jwt_token=jwt_token, - previous_response_id=chat_id, - stream=False, - ) - - # Transform response to public API format - # Internal format: {"response": "...", "response_id": "..."} - response_data = { - "response": result.get("response", ""), - "chat_id": result.get("response_id"), - "sources": result.get("sources", []), - } - - return JSONResponse(response_data) + # Call internal Langflow endpoint + return await langflow_endpoint(internal_request, chat_service, session_manager) async def chat_list_endpoint(request: Request, chat_service, session_manager): @@ -240,8 +87,8 @@ async def chat_list_endpoint(request: Request, chat_service, session_manager): user_id = user.user_id try: - # Get chat history - history = await chat_service.get_chat_history(user_id) + # Get Langflow chat history (since v1 routes through Langflow) + history = await chat_service.get_langflow_history(user_id) # Transform to public API format conversations = [] @@ -293,8 +140,8 @@ async def chat_get_endpoint(request: Request, chat_service, session_manager): ) try: - # Get chat history and find the specific conversation - history = await chat_service.get_chat_history(user_id) + # Get Langflow chat history and find the specific conversation + history = await chat_service.get_langflow_history(user_id) conversation = None for conv in history.get("conversations", []):