streaming fix
This commit is contained in:
parent
3a73d5bce3
commit
0943ab5f2e
1 changed files with 115 additions and 42 deletions
|
|
@ -2,69 +2,142 @@
|
||||||
Public API v1 Chat endpoint.
|
Public API v1 Chat endpoint.
|
||||||
|
|
||||||
Provides chat functionality with streaming support and conversation history.
|
Provides chat functionality with streaming support and conversation history.
|
||||||
Uses API key authentication. Routes through Langflow endpoint.
|
Uses API key authentication. Routes through Langflow (chat_service.langflow_chat).
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse, StreamingResponse
|
||||||
from utils.logging_config import get_logger
|
from utils.logging_config import get_logger
|
||||||
from api.chat import langflow_endpoint
|
from auth_context import set_search_filters, set_search_limit, set_score_threshold, set_auth_context
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _transform_v1_request_to_internal(data: dict) -> dict:
|
async def _transform_stream_to_sse(raw_stream, chat_id_container: dict):
|
||||||
"""Transform v1 API request format to internal Langflow format."""
|
"""
|
||||||
return {
|
Transform the raw Langflow streaming format to clean SSE events for v1 API.
|
||||||
"prompt": data.get("message", ""), # v1 uses "message", internal uses "prompt"
|
|
||||||
"previous_response_id": data.get("chat_id"), # v1 uses "chat_id"
|
Yields SSE events in the format:
|
||||||
"stream": data.get("stream", False),
|
data: {"type": "content", "delta": "..."}
|
||||||
"filters": data.get("filters"),
|
data: {"type": "sources", "sources": [...]}
|
||||||
"limit": data.get("limit", 10),
|
data: {"type": "done", "chat_id": "..."}
|
||||||
"scoreThreshold": data.get("score_threshold", 0), # v1 uses snake_case
|
"""
|
||||||
"filter_id": data.get("filter_id"),
|
full_text = ""
|
||||||
}
|
sources = []
|
||||||
|
chat_id = None
|
||||||
|
|
||||||
|
async for chunk in raw_stream:
|
||||||
|
try:
|
||||||
|
if isinstance(chunk, bytes):
|
||||||
|
chunk_str = chunk.decode("utf-8").strip()
|
||||||
|
else:
|
||||||
|
chunk_str = str(chunk).strip()
|
||||||
|
|
||||||
|
if not chunk_str:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk_data = json.loads(chunk_str)
|
||||||
|
|
||||||
|
# Extract text delta
|
||||||
|
delta_text = None
|
||||||
|
if "delta" in chunk_data:
|
||||||
|
delta = chunk_data["delta"]
|
||||||
|
if isinstance(delta, dict):
|
||||||
|
delta_text = delta.get("content") or delta.get("text") or ""
|
||||||
|
elif isinstance(delta, str):
|
||||||
|
delta_text = delta
|
||||||
|
|
||||||
|
if "output_text" in chunk_data and chunk_data["output_text"]:
|
||||||
|
delta_text = chunk_data["output_text"]
|
||||||
|
|
||||||
|
if delta_text:
|
||||||
|
full_text += delta_text
|
||||||
|
yield f"data: {json.dumps({'type': 'content', 'delta': delta_text})}\n\n"
|
||||||
|
|
||||||
|
# Extract chat_id/response_id
|
||||||
|
if "id" in chunk_data and chunk_data["id"]:
|
||||||
|
chat_id = chunk_data["id"]
|
||||||
|
elif "response_id" in chunk_data and chunk_data["response_id"]:
|
||||||
|
chat_id = chunk_data["response_id"]
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
if chunk_str and not chunk_str.startswith("{"):
|
||||||
|
yield f"data: {json.dumps({'type': 'content', 'delta': chunk_str})}\n\n"
|
||||||
|
full_text += chunk_str
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error processing stream chunk", error=str(e))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if sources:
|
||||||
|
yield f"data: {json.dumps({'type': 'sources', 'sources': sources})}\n\n"
|
||||||
|
|
||||||
|
yield f"data: {json.dumps({'type': 'done', 'chat_id': chat_id})}\n\n"
|
||||||
|
chat_id_container["chat_id"] = chat_id
|
||||||
|
|
||||||
|
|
||||||
async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
||||||
"""
|
"""
|
||||||
Send a chat message. Routes to internal Langflow endpoint.
|
Send a chat message via Langflow.
|
||||||
|
|
||||||
POST /v1/chat - see internal /langflow endpoint for full documentation.
|
POST /v1/chat
|
||||||
Transforms v1 format (message, chat_id, score_threshold) to internal format.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
except Exception:
|
except Exception:
|
||||||
return JSONResponse(
|
return JSONResponse({"error": "Invalid JSON in request body"}, status_code=400)
|
||||||
{"error": "Invalid JSON in request body"},
|
|
||||||
status_code=400,
|
|
||||||
)
|
|
||||||
|
|
||||||
message = data.get("message", "").strip()
|
message = data.get("message", "").strip()
|
||||||
if not message:
|
if not message:
|
||||||
return JSONResponse(
|
return JSONResponse({"error": "Message is required"}, status_code=400)
|
||||||
{"error": "Message is required"},
|
|
||||||
status_code=400,
|
stream = data.get("stream", False)
|
||||||
|
chat_id = data.get("chat_id")
|
||||||
|
filters = data.get("filters")
|
||||||
|
limit = data.get("limit", 10)
|
||||||
|
score_threshold = data.get("score_threshold", 0)
|
||||||
|
filter_id = data.get("filter_id")
|
||||||
|
|
||||||
|
user = request.state.user
|
||||||
|
user_id = user.user_id
|
||||||
|
jwt_token = session_manager.get_effective_jwt_token(user_id, None)
|
||||||
|
|
||||||
|
# Set context variables for search tool
|
||||||
|
if filters:
|
||||||
|
set_search_filters(filters)
|
||||||
|
set_search_limit(limit)
|
||||||
|
set_score_threshold(score_threshold)
|
||||||
|
set_auth_context(user_id, jwt_token)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
raw_stream = await chat_service.langflow_chat(
|
||||||
|
prompt=message,
|
||||||
|
user_id=user_id,
|
||||||
|
jwt_token=jwt_token,
|
||||||
|
previous_response_id=chat_id,
|
||||||
|
stream=True,
|
||||||
|
filter_id=filter_id,
|
||||||
)
|
)
|
||||||
|
chat_id_container = {}
|
||||||
# Transform v1 request to internal format
|
return StreamingResponse(
|
||||||
internal_data = _transform_v1_request_to_internal(data)
|
_transform_stream_to_sse(raw_stream, chat_id_container),
|
||||||
|
media_type="text/event-stream",
|
||||||
# Create a new request with transformed body for the internal endpoint
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||||
body = json.dumps(internal_data).encode()
|
)
|
||||||
|
else:
|
||||||
async def receive():
|
result = await chat_service.langflow_chat(
|
||||||
return {"type": "http.request", "body": body}
|
prompt=message,
|
||||||
|
user_id=user_id,
|
||||||
internal_request = Request(request.scope, receive)
|
jwt_token=jwt_token,
|
||||||
|
previous_response_id=chat_id,
|
||||||
# Copy state attributes individually (state property has no setter)
|
stream=False,
|
||||||
internal_request.state.user = request.state.user
|
filter_id=filter_id,
|
||||||
internal_request.state.jwt_token = getattr(request.state, "jwt_token", None)
|
)
|
||||||
|
# Transform response_id to chat_id for v1 API format
|
||||||
# Call internal Langflow endpoint
|
return JSONResponse({
|
||||||
return await langflow_endpoint(internal_request, chat_service, session_manager)
|
"response": result.get("response", ""),
|
||||||
|
"chat_id": result.get("response_id"),
|
||||||
|
"sources": result.get("sources", []),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
async def chat_list_endpoint(request: Request, chat_service, session_manager):
|
async def chat_list_endpoint(request: Request, chat_service, session_manager):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue