Merge pull request #724 from langflow-ai/sdk-chat-fix

SDK chat endpoint fix
This commit is contained in:
Sebastián Estévez 2025-12-22 23:15:47 -05:00 committed by GitHub
commit a252c94312
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -2,7 +2,7 @@
Public API v1 Chat endpoint. Public API v1 Chat endpoint.
Provides chat functionality with streaming support and conversation history. Provides chat functionality with streaming support and conversation history.
Uses API key authentication. Uses API key authentication. Routes through Langflow (chat_service.langflow_chat).
""" """
import json import json
from starlette.requests import Request from starlette.requests import Request
@ -15,25 +15,18 @@ logger = get_logger(__name__)
async def _transform_stream_to_sse(raw_stream, chat_id_container: dict): async def _transform_stream_to_sse(raw_stream, chat_id_container: dict):
""" """
Transform the raw internal streaming format to clean SSE events. Transform the raw Langflow streaming format to clean SSE events for v1 API.
Yields SSE events in the format: Yields SSE events in the format:
event: content
data: {"type": "content", "delta": "..."} data: {"type": "content", "delta": "..."}
event: sources
data: {"type": "sources", "sources": [...]} data: {"type": "sources", "sources": [...]}
event: done
data: {"type": "done", "chat_id": "..."} data: {"type": "done", "chat_id": "..."}
""" """
full_text = "" full_text = ""
sources = []
chat_id = None chat_id = None
async for chunk in raw_stream: async for chunk in raw_stream:
try: try:
# Decode the chunk
if isinstance(chunk, bytes): if isinstance(chunk, bytes):
chunk_str = chunk.decode("utf-8").strip() chunk_str = chunk.decode("utf-8").strip()
else: else:
@ -42,131 +35,76 @@ async def _transform_stream_to_sse(raw_stream, chat_id_container: dict):
if not chunk_str: if not chunk_str:
continue continue
# Parse the JSON chunk
chunk_data = json.loads(chunk_str) chunk_data = json.loads(chunk_str)
# Extract text delta # Extract text from various possible formats
delta_text = None delta_text = ""
# Format 1: delta.content (OpenAI-style)
if "delta" in chunk_data: if "delta" in chunk_data:
delta = chunk_data["delta"] delta = chunk_data["delta"]
if isinstance(delta, dict): if isinstance(delta, dict):
delta_text = delta.get("content") or delta.get("text") or "" delta_text = delta.get("content", "") or delta.get("text", "")
elif isinstance(delta, str): elif isinstance(delta, str):
delta_text = delta delta_text = delta
if "output_text" in chunk_data and chunk_data["output_text"]: # Format 2: output_text (Langflow-style)
if not delta_text and chunk_data.get("output_text"):
delta_text = chunk_data["output_text"] delta_text = chunk_data["output_text"]
# Yield content event if we have text # Format 3: text field directly
if not delta_text and chunk_data.get("text"):
delta_text = chunk_data["text"]
# Format 4: content field directly
if not delta_text and chunk_data.get("content"):
delta_text = chunk_data["content"]
if delta_text: if delta_text:
full_text += delta_text full_text += delta_text
event = {"type": "content", "delta": delta_text} yield f"data: {json.dumps({'type': 'content', 'delta': delta_text})}\n\n"
yield f"event: content\ndata: {json.dumps(event)}\n\n"
# Extract chat_id/response_id # Extract chat_id/response_id from various fields
if "id" in chunk_data and chunk_data["id"]: if not chat_id:
chat_id = chunk_data["id"] chat_id = chunk_data.get("id") or chunk_data.get("response_id")
elif "response_id" in chunk_data and chunk_data["response_id"]:
chat_id = chunk_data["response_id"]
# Extract sources from tool call results
if "item" in chunk_data and isinstance(chunk_data["item"], dict):
item = chunk_data["item"]
if item.get("type") in ("retrieval_call", "tool_call", "function_call"):
results = item.get("results", [])
if results:
for result in results:
if isinstance(result, dict):
source = {
"filename": result.get("filename", result.get("title", "Unknown")),
"text": result.get("text", result.get("content", "")),
"score": result.get("score", 0),
"page": result.get("page"),
}
sources.append(source)
except json.JSONDecodeError: except json.JSONDecodeError:
# Not JSON, might be raw text # Raw text without JSON wrapper
if chunk_str and not chunk_str.startswith("{"): if chunk_str:
event = {"type": "content", "delta": chunk_str} yield f"data: {json.dumps({'type': 'content', 'delta': chunk_str})}\n\n"
yield f"event: content\ndata: {json.dumps(event)}\n\n"
full_text += chunk_str full_text += chunk_str
except Exception as e: except Exception as e:
logger.warning("Error processing stream chunk", error=str(e)) logger.warning("Error processing stream chunk", error=str(e), chunk=chunk_str[:100] if chunk_str else "")
continue
# Yield sources event if we have any yield f"data: {json.dumps({'type': 'done', 'chat_id': chat_id})}\n\n"
if sources:
event = {"type": "sources", "sources": sources}
yield f"event: sources\ndata: {json.dumps(event)}\n\n"
# Yield done event with chat_id
event = {"type": "done", "chat_id": chat_id}
yield f"event: done\ndata: {json.dumps(event)}\n\n"
# Store chat_id for caller
chat_id_container["chat_id"] = chat_id chat_id_container["chat_id"] = chat_id
async def chat_create_endpoint(request: Request, chat_service, session_manager): async def chat_create_endpoint(request: Request, chat_service, session_manager):
""" """
Send a chat message. Send a chat message via Langflow.
POST /v1/chat POST /v1/chat
Request body:
{
"message": "What is RAG?",
"stream": false, // optional, default false
"chat_id": "...", // optional, to continue conversation
"filters": {...}, // optional
"limit": 10, // optional
"score_threshold": 0.5 // optional
}
Non-streaming response:
{
"response": "RAG stands for...",
"chat_id": "chat_xyz789",
"sources": [...]
}
Streaming response (SSE):
event: content
data: {"type": "content", "delta": "RAG stands for"}
event: sources
data: {"type": "sources", "sources": [...]}
event: done
data: {"type": "done", "chat_id": "chat_xyz789"}
""" """
try: try:
data = await request.json() data = await request.json()
except Exception: except Exception:
return JSONResponse( return JSONResponse({"error": "Invalid JSON in request body"}, status_code=400)
{"error": "Invalid JSON in request body"},
status_code=400,
)
message = data.get("message", "").strip() message = data.get("message", "").strip()
if not message: if not message:
return JSONResponse( return JSONResponse({"error": "Message is required"}, status_code=400)
{"error": "Message is required"},
status_code=400,
)
stream = data.get("stream", False) stream = data.get("stream", False)
chat_id = data.get("chat_id") # For conversation continuation chat_id = data.get("chat_id")
filters = data.get("filters") filters = data.get("filters")
limit = data.get("limit", 10) limit = data.get("limit", 10)
score_threshold = data.get("score_threshold", 0) score_threshold = data.get("score_threshold", 0)
filter_id = data.get("filter_id")
user = request.state.user user = request.state.user
user_id = user.user_id user_id = user.user_id
jwt_token = session_manager.get_effective_jwt_token(user_id, None)
# Note: API key auth doesn't have JWT, so we pass None
jwt_token = None
# Set context variables for search tool # Set context variables for search tool
if filters: if filters:
@ -176,45 +114,35 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
set_auth_context(user_id, jwt_token) set_auth_context(user_id, jwt_token)
if stream: if stream:
# Streaming response raw_stream = await chat_service.langflow_chat(
raw_stream = await chat_service.chat(
prompt=message, prompt=message,
user_id=user_id, user_id=user_id,
jwt_token=jwt_token, jwt_token=jwt_token,
previous_response_id=chat_id, previous_response_id=chat_id,
stream=True, stream=True,
filter_id=filter_id,
) )
chat_id_container = {} chat_id_container = {}
return StreamingResponse( return StreamingResponse(
_transform_stream_to_sse(raw_stream, chat_id_container), _transform_stream_to_sse(raw_stream, chat_id_container),
media_type="text/event-stream", media_type="text/event-stream",
headers={ headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
) )
else: else:
# Non-streaming response result = await chat_service.langflow_chat(
result = await chat_service.chat(
prompt=message, prompt=message,
user_id=user_id, user_id=user_id,
jwt_token=jwt_token, jwt_token=jwt_token,
previous_response_id=chat_id, previous_response_id=chat_id,
stream=False, stream=False,
filter_id=filter_id,
) )
# Transform response_id to chat_id for v1 API format
# Transform response to public API format return JSONResponse({
# Internal format: {"response": "...", "response_id": "..."}
response_data = {
"response": result.get("response", ""), "response": result.get("response", ""),
"chat_id": result.get("response_id"), "chat_id": result.get("response_id"),
"sources": result.get("sources", []), "sources": result.get("sources", []),
} })
return JSONResponse(response_data)
async def chat_list_endpoint(request: Request, chat_service, session_manager): async def chat_list_endpoint(request: Request, chat_service, session_manager):
@ -240,8 +168,8 @@ async def chat_list_endpoint(request: Request, chat_service, session_manager):
user_id = user.user_id user_id = user.user_id
try: try:
# Get chat history # Get Langflow chat history (since v1 routes through Langflow)
history = await chat_service.get_chat_history(user_id) history = await chat_service.get_langflow_history(user_id)
# Transform to public API format # Transform to public API format
conversations = [] conversations = []
@ -293,8 +221,8 @@ async def chat_get_endpoint(request: Request, chat_service, session_manager):
) )
try: try:
# Get chat history and find the specific conversation # Get Langflow chat history and find the specific conversation
history = await chat_service.get_chat_history(user_id) history = await chat_service.get_langflow_history(user_id)
conversation = None conversation = None
for conv in history.get("conversations", []): for conv in history.get("conversations", []):