Merge pull request #724 from langflow-ai/sdk-chat-fix

SDK chat endpoint fix
Sebastián Estévez 2025-12-22 23:15:47 -05:00 committed by GitHub
commit a252c94312


@@ -2,7 +2,7 @@
Public API v1 Chat endpoint.
Provides chat functionality with streaming support and conversation history.
Uses API key authentication.
Uses API key authentication. Routes through Langflow (chat_service.langflow_chat).
"""
import json
from starlette.requests import Request
@@ -15,25 +15,18 @@ logger = get_logger(__name__)
async def _transform_stream_to_sse(raw_stream, chat_id_container: dict):
"""
Transform the raw internal streaming format to clean SSE events.
Transform the raw Langflow streaming format to clean SSE events for v1 API.
Yields SSE events in the format:
event: content
data: {"type": "content", "delta": "..."}
event: sources
data: {"type": "sources", "sources": [...]}
event: done
data: {"type": "done", "chat_id": "..."}
"""
full_text = ""
sources = []
chat_id = None
async for chunk in raw_stream:
try:
# Decode the chunk
if isinstance(chunk, bytes):
chunk_str = chunk.decode("utf-8").strip()
else:
@@ -42,131 +35,76 @@ async def _transform_stream_to_sse(raw_stream, chat_id_container: dict):
if not chunk_str:
continue
# Parse the JSON chunk
chunk_data = json.loads(chunk_str)
# Extract text delta
delta_text = None
# Extract text from various possible formats
delta_text = ""
# Format 1: delta.content (OpenAI-style)
if "delta" in chunk_data:
delta = chunk_data["delta"]
if isinstance(delta, dict):
delta_text = delta.get("content") or delta.get("text") or ""
delta_text = delta.get("content", "") or delta.get("text", "")
elif isinstance(delta, str):
delta_text = delta
if "output_text" in chunk_data and chunk_data["output_text"]:
# Format 2: output_text (Langflow-style)
if not delta_text and chunk_data.get("output_text"):
delta_text = chunk_data["output_text"]
# Yield content event if we have text
# Format 3: text field directly
if not delta_text and chunk_data.get("text"):
delta_text = chunk_data["text"]
# Format 4: content field directly
if not delta_text and chunk_data.get("content"):
delta_text = chunk_data["content"]
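# Illustrative chunk shapes (assumed for clarity, not taken from the actual Langflow stream):
#   {"delta": {"content": "Hi"}}          -> Format 1
#   {"output_text": "Hi"}                 -> Format 2
#   {"text": "Hi"} / {"content": "Hi"}    -> Formats 3 / 4
# Each of these ends up emitted as: data: {"type": "content", "delta": "Hi"}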
if delta_text:
full_text += delta_text
event = {"type": "content", "delta": delta_text}
yield f"event: content\ndata: {json.dumps(event)}\n\n"
yield f"data: {json.dumps({'type': 'content', 'delta': delta_text})}\n\n"
# Extract chat_id/response_id
if "id" in chunk_data and chunk_data["id"]:
chat_id = chunk_data["id"]
elif "response_id" in chunk_data and chunk_data["response_id"]:
chat_id = chunk_data["response_id"]
# Extract sources from tool call results
if "item" in chunk_data and isinstance(chunk_data["item"], dict):
item = chunk_data["item"]
if item.get("type") in ("retrieval_call", "tool_call", "function_call"):
results = item.get("results", [])
if results:
for result in results:
if isinstance(result, dict):
source = {
"filename": result.get("filename", result.get("title", "Unknown")),
"text": result.get("text", result.get("content", "")),
"score": result.get("score", 0),
"page": result.get("page"),
}
sources.append(source)
# Extract chat_id/response_id from various fields
if not chat_id:
chat_id = chunk_data.get("id") or chunk_data.get("response_id")
except json.JSONDecodeError:
# Not JSON, might be raw text
if chunk_str and not chunk_str.startswith("{"):
event = {"type": "content", "delta": chunk_str}
yield f"event: content\ndata: {json.dumps(event)}\n\n"
# Raw text without JSON wrapper
if chunk_str:
yield f"data: {json.dumps({'type': 'content', 'delta': chunk_str})}\n\n"
full_text += chunk_str
except Exception as e:
logger.warning("Error processing stream chunk", error=str(e))
continue
logger.warning("Error processing stream chunk", error=str(e), chunk=chunk_str[:100] if chunk_str else "")
# Yield sources event if we have any
if sources:
event = {"type": "sources", "sources": sources}
yield f"event: sources\ndata: {json.dumps(event)}\n\n"
# Yield done event with chat_id
event = {"type": "done", "chat_id": chat_id}
yield f"event: done\ndata: {json.dumps(event)}\n\n"
# Store chat_id for caller
yield f"data: {json.dumps({'type': 'done', 'chat_id': chat_id})}\n\n"
chat_id_container["chat_id"] = chat_id
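For orientation, a minimal client-side sketch of consuming these SSE events (assuming an httpx async client and an X-API-Key header; both names are illustrative and not part of this commit):

import json
import httpx

async def read_chat_stream(base_url: str, api_key: str, message: str):
    # Stream POST /v1/chat and collect the content deltas plus the final chat_id.
    async with httpx.AsyncClient(base_url=base_url) as client:
        async with client.stream(
            "POST",
            "/v1/chat",
            json={"message": message, "stream": True},
            headers={"X-API-Key": api_key},  # assumed header; adjust to the deployment's auth scheme
        ) as response:
            text, chat_id = "", None
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                event = json.loads(line[len("data: "):])
                if event.get("type") == "content":
                    text += event.get("delta", "")
                elif event.get("type") == "done":
                    chat_id = event.get("chat_id")
            return text, chat_id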
async def chat_create_endpoint(request: Request, chat_service, session_manager):
"""
Send a chat message.
Send a chat message via Langflow.
POST /v1/chat
Request body:
{
"message": "What is RAG?",
"stream": false, // optional, default false
"chat_id": "...", // optional, to continue conversation
"filters": {...}, // optional
"limit": 10, // optional
"score_threshold": 0.5 // optional
}
Non-streaming response:
{
"response": "RAG stands for...",
"chat_id": "chat_xyz789",
"sources": [...]
}
Streaming response (SSE):
event: content
data: {"type": "content", "delta": "RAG stands for"}
event: sources
data: {"type": "sources", "sources": [...]}
event: done
data: {"type": "done", "chat_id": "chat_xyz789"}
"""
try:
data = await request.json()
except Exception:
return JSONResponse(
{"error": "Invalid JSON in request body"},
status_code=400,
)
return JSONResponse({"error": "Invalid JSON in request body"}, status_code=400)
message = data.get("message", "").strip()
if not message:
return JSONResponse(
{"error": "Message is required"},
status_code=400,
)
return JSONResponse({"error": "Message is required"}, status_code=400)
stream = data.get("stream", False)
chat_id = data.get("chat_id") # For conversation continuation
chat_id = data.get("chat_id")
filters = data.get("filters")
limit = data.get("limit", 10)
score_threshold = data.get("score_threshold", 0)
filter_id = data.get("filter_id")
user = request.state.user
user_id = user.user_id
# Note: API key auth doesn't have JWT, so we pass None
jwt_token = None
jwt_token = session_manager.get_effective_jwt_token(user_id, None)
# Set context variables for search tool
if filters:
@@ -176,45 +114,35 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
set_auth_context(user_id, jwt_token)
if stream:
# Streaming response
raw_stream = await chat_service.chat(
raw_stream = await chat_service.langflow_chat(
prompt=message,
user_id=user_id,
jwt_token=jwt_token,
previous_response_id=chat_id,
stream=True,
filter_id=filter_id,
)
chat_id_container = {}
return StreamingResponse(
_transform_stream_to_sse(raw_stream, chat_id_container),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
else:
# Non-streaming response
result = await chat_service.chat(
result = await chat_service.langflow_chat(
prompt=message,
user_id=user_id,
jwt_token=jwt_token,
previous_response_id=chat_id,
stream=False,
filter_id=filter_id,
)
# Transform response to public API format
# Internal format: {"response": "...", "response_id": "..."}
response_data = {
# Transform response_id to chat_id for v1 API format
return JSONResponse({
"response": result.get("response", ""),
"chat_id": result.get("response_id"),
"sources": result.get("sources", []),
}
return JSONResponse(response_data)
})
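The non-streaming path can be exercised the same way; a sketch under the same assumptions (httpx client, X-API-Key header, both illustrative):

import httpx

async def ask_once(base_url: str, api_key: str, message: str, chat_id: str | None = None):
    # Single-shot POST /v1/chat; pass a previous chat_id to continue a conversation.
    async with httpx.AsyncClient(base_url=base_url) as client:
        resp = await client.post(
            "/v1/chat",
            json={"message": message, "stream": False, "chat_id": chat_id},
            headers={"X-API-Key": api_key},  # assumed header name
        )
        resp.raise_for_status()
        body = resp.json()
        # Documented shape: {"response": "...", "chat_id": "...", "sources": [...]}
        return body.get("response", ""), body.get("chat_id")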
async def chat_list_endpoint(request: Request, chat_service, session_manager):
@@ -240,8 +168,8 @@ async def chat_list_endpoint(request: Request, chat_service, session_manager):
user_id = user.user_id
try:
# Get chat history
history = await chat_service.get_chat_history(user_id)
# Get Langflow chat history (since v1 routes through Langflow)
history = await chat_service.get_langflow_history(user_id)
# Transform to public API format
conversations = []
@@ -293,8 +221,8 @@ async def chat_get_endpoint(request: Request, chat_service, session_manager):
)
try:
# Get chat history and find the specific conversation
history = await chat_service.get_chat_history(user_id)
# Get Langflow chat history and find the specific conversation
history = await chat_service.get_langflow_history(user_id)
conversation = None
for conv in history.get("conversations", []):