streaming fix

This commit is contained in:
phact 2025-12-22 22:47:18 -05:00
parent 3a73d5bce3
commit 0943ab5f2e

View file

@ -2,69 +2,142 @@
Public API v1 Chat endpoint. Public API v1 Chat endpoint.
Provides chat functionality with streaming support and conversation history. Provides chat functionality with streaming support and conversation history.
Uses API key authentication. Routes through Langflow endpoint. Uses API key authentication. Routes through Langflow (chat_service.langflow_chat).
""" """
import json import json
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse, StreamingResponse
from utils.logging_config import get_logger from utils.logging_config import get_logger
from api.chat import langflow_endpoint from auth_context import set_search_filters, set_search_limit, set_score_threshold, set_auth_context
logger = get_logger(__name__) logger = get_logger(__name__)
def _transform_v1_request_to_internal(data: dict) -> dict: async def _transform_stream_to_sse(raw_stream, chat_id_container: dict):
"""Transform v1 API request format to internal Langflow format.""" """
return { Transform the raw Langflow streaming format to clean SSE events for v1 API.
"prompt": data.get("message", ""), # v1 uses "message", internal uses "prompt"
"previous_response_id": data.get("chat_id"), # v1 uses "chat_id" Yields SSE events in the format:
"stream": data.get("stream", False), data: {"type": "content", "delta": "..."}
"filters": data.get("filters"), data: {"type": "sources", "sources": [...]}
"limit": data.get("limit", 10), data: {"type": "done", "chat_id": "..."}
"scoreThreshold": data.get("score_threshold", 0), # v1 uses snake_case """
"filter_id": data.get("filter_id"), full_text = ""
} sources = []
chat_id = None
async for chunk in raw_stream:
try:
if isinstance(chunk, bytes):
chunk_str = chunk.decode("utf-8").strip()
else:
chunk_str = str(chunk).strip()
if not chunk_str:
continue
chunk_data = json.loads(chunk_str)
# Extract text delta
delta_text = None
if "delta" in chunk_data:
delta = chunk_data["delta"]
if isinstance(delta, dict):
delta_text = delta.get("content") or delta.get("text") or ""
elif isinstance(delta, str):
delta_text = delta
if "output_text" in chunk_data and chunk_data["output_text"]:
delta_text = chunk_data["output_text"]
if delta_text:
full_text += delta_text
yield f"data: {json.dumps({'type': 'content', 'delta': delta_text})}\n\n"
# Extract chat_id/response_id
if "id" in chunk_data and chunk_data["id"]:
chat_id = chunk_data["id"]
elif "response_id" in chunk_data and chunk_data["response_id"]:
chat_id = chunk_data["response_id"]
except json.JSONDecodeError:
if chunk_str and not chunk_str.startswith("{"):
yield f"data: {json.dumps({'type': 'content', 'delta': chunk_str})}\n\n"
full_text += chunk_str
except Exception as e:
logger.warning("Error processing stream chunk", error=str(e))
continue
if sources:
yield f"data: {json.dumps({'type': 'sources', 'sources': sources})}\n\n"
yield f"data: {json.dumps({'type': 'done', 'chat_id': chat_id})}\n\n"
chat_id_container["chat_id"] = chat_id
async def chat_create_endpoint(request: Request, chat_service, session_manager): async def chat_create_endpoint(request: Request, chat_service, session_manager):
""" """
Send a chat message. Routes to internal Langflow endpoint. Send a chat message via Langflow.
POST /v1/chat - see internal /langflow endpoint for full documentation. POST /v1/chat
Transforms v1 format (message, chat_id, score_threshold) to internal format.
""" """
try: try:
data = await request.json() data = await request.json()
except Exception: except Exception:
return JSONResponse( return JSONResponse({"error": "Invalid JSON in request body"}, status_code=400)
{"error": "Invalid JSON in request body"},
status_code=400,
)
message = data.get("message", "").strip() message = data.get("message", "").strip()
if not message: if not message:
return JSONResponse( return JSONResponse({"error": "Message is required"}, status_code=400)
{"error": "Message is required"},
status_code=400, stream = data.get("stream", False)
chat_id = data.get("chat_id")
filters = data.get("filters")
limit = data.get("limit", 10)
score_threshold = data.get("score_threshold", 0)
filter_id = data.get("filter_id")
user = request.state.user
user_id = user.user_id
jwt_token = session_manager.get_effective_jwt_token(user_id, None)
# Set context variables for search tool
if filters:
set_search_filters(filters)
set_search_limit(limit)
set_score_threshold(score_threshold)
set_auth_context(user_id, jwt_token)
if stream:
raw_stream = await chat_service.langflow_chat(
prompt=message,
user_id=user_id,
jwt_token=jwt_token,
previous_response_id=chat_id,
stream=True,
filter_id=filter_id,
) )
chat_id_container = {}
# Transform v1 request to internal format return StreamingResponse(
internal_data = _transform_v1_request_to_internal(data) _transform_stream_to_sse(raw_stream, chat_id_container),
media_type="text/event-stream",
# Create a new request with transformed body for the internal endpoint headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
body = json.dumps(internal_data).encode() )
else:
async def receive(): result = await chat_service.langflow_chat(
return {"type": "http.request", "body": body} prompt=message,
user_id=user_id,
internal_request = Request(request.scope, receive) jwt_token=jwt_token,
previous_response_id=chat_id,
# Copy state attributes individually (state property has no setter) stream=False,
internal_request.state.user = request.state.user filter_id=filter_id,
internal_request.state.jwt_token = getattr(request.state, "jwt_token", None) )
# Transform response_id to chat_id for v1 API format
# Call internal Langflow endpoint return JSONResponse({
return await langflow_endpoint(internal_request, chat_service, session_manager) "response": result.get("response", ""),
"chat_id": result.get("response_id"),
"sources": result.get("sources", []),
})
async def chat_list_endpoint(request: Request, chat_service, session_manager): async def chat_list_endpoint(request: Request, chat_service, session_manager):