streaming fix

phact 2025-12-22 22:47:18 -05:00
parent 3a73d5bce3
commit 0943ab5f2e


@@ -2,69 +2,142 @@
Public API v1 Chat endpoint.
Provides chat functionality with streaming support and conversation history.
Uses API key authentication. Routes through Langflow endpoint.
Uses API key authentication. Routes through Langflow (chat_service.langflow_chat).
"""
import json
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.responses import JSONResponse, StreamingResponse
from utils.logging_config import get_logger
from api.chat import langflow_endpoint
from auth_context import set_search_filters, set_search_limit, set_score_threshold, set_auth_context
logger = get_logger(__name__)
def _transform_v1_request_to_internal(data: dict) -> dict:
    """Transform v1 API request format to internal Langflow format."""
    return {
        "prompt": data.get("message", ""),  # v1 uses "message", internal uses "prompt"
        "previous_response_id": data.get("chat_id"),  # v1 uses "chat_id"
        "stream": data.get("stream", False),
        "filters": data.get("filters"),
        "limit": data.get("limit", 10),
        "scoreThreshold": data.get("score_threshold", 0),  # v1 uses snake_case
        "filter_id": data.get("filter_id"),
    }
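# Illustrative only (hypothetical values): a v1 body such as
#   {"message": "hi", "chat_id": "resp_123", "score_threshold": 0.5}
# is mapped by the helper above to the internal Langflow shape, with any
# unspecified keys falling back to their defaults:
_EXAMPLE_INTERNAL_BODY = {
    "prompt": "hi",
    "previous_response_id": "resp_123",
    "stream": False,
    "filters": None,
    "limit": 10,
    "scoreThreshold": 0.5,
    "filter_id": None,
}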
async def _transform_stream_to_sse(raw_stream, chat_id_container: dict):
    """
    Transform the raw Langflow streaming format to clean SSE events for the v1 API.
    Yields SSE events in the format:
        data: {"type": "content", "delta": "..."}
        data: {"type": "sources", "sources": [...]}
        data: {"type": "done", "chat_id": "..."}
    """
    full_text = ""
    sources = []  # reserved for source citations; not currently populated from chunks
    chat_id = None
    async for chunk in raw_stream:
        try:
            if isinstance(chunk, bytes):
                chunk_str = chunk.decode("utf-8").strip()
            else:
                chunk_str = str(chunk).strip()
            if not chunk_str:
                continue
            chunk_data = json.loads(chunk_str)
            # Extract text delta
            delta_text = None
            if "delta" in chunk_data:
                delta = chunk_data["delta"]
                if isinstance(delta, dict):
                    delta_text = delta.get("content") or delta.get("text") or ""
                elif isinstance(delta, str):
                    delta_text = delta
            if "output_text" in chunk_data and chunk_data["output_text"]:
                delta_text = chunk_data["output_text"]
            if delta_text:
                full_text += delta_text
                yield f"data: {json.dumps({'type': 'content', 'delta': delta_text})}\n\n"
            # Extract chat_id/response_id
            if "id" in chunk_data and chunk_data["id"]:
                chat_id = chunk_data["id"]
            elif "response_id" in chunk_data and chunk_data["response_id"]:
                chat_id = chunk_data["response_id"]
        except json.JSONDecodeError:
            # Pass through plain-text chunks that are not JSON objects
            if chunk_str and not chunk_str.startswith("{"):
                yield f"data: {json.dumps({'type': 'content', 'delta': chunk_str})}\n\n"
                full_text += chunk_str
        except Exception as e:
            logger.warning("Error processing stream chunk", error=str(e))
            continue
    if sources:
        yield f"data: {json.dumps({'type': 'sources', 'sources': sources})}\n\n"
    yield f"data: {json.dumps({'type': 'done', 'chat_id': chat_id})}\n\n"
    chat_id_container["chat_id"] = chat_id
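# A minimal sketch of how the generator above behaves, driven by a fake raw
# stream (hypothetical chunk payloads; the real Langflow chunk shape may differ):
async def _example_drive_sse_transform():
    async def fake_raw_stream():
        yield b'{"delta": {"content": "Hel"}}'
        yield b'{"delta": {"content": "lo"}, "id": "resp_123"}'
    container = {}
    async for event in _transform_stream_to_sse(fake_raw_stream(), container):
        print(event, end="")
    # Prints three SSE frames:
    #   data: {"type": "content", "delta": "Hel"}
    #   data: {"type": "content", "delta": "lo"}
    #   data: {"type": "done", "chat_id": "resp_123"}
    # and leaves container == {"chat_id": "resp_123"}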
async def chat_create_endpoint(request: Request, chat_service, session_manager):
    """
    Send a chat message. Routes to internal Langflow endpoint.
    Send a chat message via Langflow.
    POST /v1/chat - see internal /langflow endpoint for full documentation.
    Transforms v1 format (message, chat_id, score_threshold) to internal format.
    POST /v1/chat
    """
    try:
        data = await request.json()
    except Exception:
        return JSONResponse(
            {"error": "Invalid JSON in request body"},
            status_code=400,
        )
        return JSONResponse({"error": "Invalid JSON in request body"}, status_code=400)
    message = data.get("message", "").strip()
    if not message:
        return JSONResponse(
            {"error": "Message is required"},
            status_code=400,
        )
        return JSONResponse({"error": "Message is required"}, status_code=400)
    stream = data.get("stream", False)
    chat_id = data.get("chat_id")
    filters = data.get("filters")
    limit = data.get("limit", 10)
    score_threshold = data.get("score_threshold", 0)
    filter_id = data.get("filter_id")
    user = request.state.user
    user_id = user.user_id
    jwt_token = session_manager.get_effective_jwt_token(user_id, None)
    # Set context variables for search tool
    if filters:
        set_search_filters(filters)
    set_search_limit(limit)
    set_score_threshold(score_threshold)
    set_auth_context(user_id, jwt_token)
    if stream:
        raw_stream = await chat_service.langflow_chat(
            prompt=message,
            user_id=user_id,
            jwt_token=jwt_token,
            previous_response_id=chat_id,
            stream=True,
            filter_id=filter_id,
        )
    # Transform v1 request to internal format
    internal_data = _transform_v1_request_to_internal(data)
    # Create a new request with transformed body for the internal endpoint
    body = json.dumps(internal_data).encode()
    async def receive():
        return {"type": "http.request", "body": body}
    internal_request = Request(request.scope, receive)
    # Copy state attributes individually (state property has no setter)
    internal_request.state.user = request.state.user
    internal_request.state.jwt_token = getattr(request.state, "jwt_token", None)
    # Call internal Langflow endpoint
    return await langflow_endpoint(internal_request, chat_service, session_manager)
        chat_id_container = {}
        return StreamingResponse(
            _transform_stream_to_sse(raw_stream, chat_id_container),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
        )
    else:
        result = await chat_service.langflow_chat(
            prompt=message,
            user_id=user_id,
            jwt_token=jwt_token,
            previous_response_id=chat_id,
            stream=False,
            filter_id=filter_id,
        )
        # Transform response_id to chat_id for v1 API format
        return JSONResponse({
            "response": result.get("response", ""),
            "chat_id": result.get("response_id"),
            "sources": result.get("sources", []),
        })
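# Client-side sketch (not used by the server): illustrates the v1 wire format
# the endpoint above accepts and returns. The base URL and the API-key header
# name below are assumptions; adjust them to the actual deployment.
async def _example_v1_chat_client(api_key: str, message: str):
    import httpx  # local import so the endpoint module itself does not require httpx
    headers = {"Authorization": f"Bearer {api_key}"}  # assumed header scheme
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        # Non-streaming: a single JSON object with "response", "chat_id", "sources".
        resp = await client.post("/v1/chat", json={"message": message}, headers=headers)
        print(resp.json())
        # Streaming: SSE frames shaped like 'data: {"type": "content", "delta": "..."}'.
        async with client.stream(
            "POST", "/v1/chat", json={"message": message, "stream": True}, headers=headers
        ) as stream_resp:
            async for line in stream_resp.aiter_lines():
                if line.startswith("data: "):
                    print(json.loads(line[len("data: "):]))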
async def chat_list_endpoint(request: Request, chat_service, session_manager):