sdk chat endpoint fix

This commit is contained in:
phact 2025-12-22 22:00:42 -05:00
parent 900e5f9795
commit 850c504c55

View file

@ -2,144 +2,36 @@
Public API v1 Chat endpoint.
Provides chat functionality with streaming support and conversation history.
Uses API key authentication.
Uses API key authentication. Routes through Langflow endpoint.
"""
import json
from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse
from starlette.responses import JSONResponse
from utils.logging_config import get_logger
from auth_context import set_search_filters, set_search_limit, set_score_threshold, set_auth_context
from api.chat import langflow_endpoint
logger = get_logger(__name__)
async def _transform_stream_to_sse(raw_stream, chat_id_container: dict):
"""
Transform the raw internal streaming format to clean SSE events.
Yields SSE events in the format:
event: content
data: {"type": "content", "delta": "..."}
event: sources
data: {"type": "sources", "sources": [...]}
event: done
data: {"type": "done", "chat_id": "..."}
"""
full_text = ""
sources = []
chat_id = None
async for chunk in raw_stream:
try:
# Decode the chunk
if isinstance(chunk, bytes):
chunk_str = chunk.decode("utf-8").strip()
else:
chunk_str = str(chunk).strip()
if not chunk_str:
continue
# Parse the JSON chunk
chunk_data = json.loads(chunk_str)
# Extract text delta
delta_text = None
if "delta" in chunk_data:
delta = chunk_data["delta"]
if isinstance(delta, dict):
delta_text = delta.get("content") or delta.get("text") or ""
elif isinstance(delta, str):
delta_text = delta
if "output_text" in chunk_data and chunk_data["output_text"]:
delta_text = chunk_data["output_text"]
# Yield content event if we have text
if delta_text:
full_text += delta_text
event = {"type": "content", "delta": delta_text}
yield f"event: content\ndata: {json.dumps(event)}\n\n"
# Extract chat_id/response_id
if "id" in chunk_data and chunk_data["id"]:
chat_id = chunk_data["id"]
elif "response_id" in chunk_data and chunk_data["response_id"]:
chat_id = chunk_data["response_id"]
# Extract sources from tool call results
if "item" in chunk_data and isinstance(chunk_data["item"], dict):
item = chunk_data["item"]
if item.get("type") in ("retrieval_call", "tool_call", "function_call"):
results = item.get("results", [])
if results:
for result in results:
if isinstance(result, dict):
source = {
"filename": result.get("filename", result.get("title", "Unknown")),
"text": result.get("text", result.get("content", "")),
"score": result.get("score", 0),
"page": result.get("page"),
}
sources.append(source)
except json.JSONDecodeError:
# Not JSON, might be raw text
if chunk_str and not chunk_str.startswith("{"):
event = {"type": "content", "delta": chunk_str}
yield f"event: content\ndata: {json.dumps(event)}\n\n"
full_text += chunk_str
except Exception as e:
logger.warning("Error processing stream chunk", error=str(e))
continue
# Yield sources event if we have any
if sources:
event = {"type": "sources", "sources": sources}
yield f"event: sources\ndata: {json.dumps(event)}\n\n"
# Yield done event with chat_id
event = {"type": "done", "chat_id": chat_id}
yield f"event: done\ndata: {json.dumps(event)}\n\n"
# Store chat_id for caller
chat_id_container["chat_id"] = chat_id
def _transform_v1_request_to_internal(data: dict) -> dict:
"""Transform v1 API request format to internal Langflow format."""
return {
"prompt": data.get("message", ""), # v1 uses "message", internal uses "prompt"
"previous_response_id": data.get("chat_id"), # v1 uses "chat_id"
"stream": data.get("stream", False),
"filters": data.get("filters"),
"limit": data.get("limit", 10),
"scoreThreshold": data.get("score_threshold", 0), # v1 uses snake_case
"filter_id": data.get("filter_id"),
}
async def chat_create_endpoint(request: Request, chat_service, session_manager):
"""
Send a chat message.
Send a chat message. Routes to internal Langflow endpoint.
POST /v1/chat
Request body:
{
"message": "What is RAG?",
"stream": false, // optional, default false
"chat_id": "...", // optional, to continue conversation
"filters": {...}, // optional
"limit": 10, // optional
"score_threshold": 0.5 // optional
}
Non-streaming response:
{
"response": "RAG stands for...",
"chat_id": "chat_xyz789",
"sources": [...]
}
Streaming response (SSE):
event: content
data: {"type": "content", "delta": "RAG stands for"}
event: sources
data: {"type": "sources", "sources": [...]}
event: done
data: {"type": "done", "chat_id": "chat_xyz789"}
POST /v1/chat - see internal /langflow endpoint for full documentation.
Transforms v1 format (message, chat_id, score_threshold) to internal format.
"""
try:
data = await request.json()
@ -156,65 +48,20 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
status_code=400,
)
stream = data.get("stream", False)
chat_id = data.get("chat_id") # For conversation continuation
filters = data.get("filters")
limit = data.get("limit", 10)
score_threshold = data.get("score_threshold", 0)
# Transform v1 request to internal format
internal_data = _transform_v1_request_to_internal(data)
user = request.state.user
user_id = user.user_id
# Create a new request with transformed body for the internal endpoint
body = json.dumps(internal_data).encode()
# Note: API key auth doesn't have JWT, so we pass None
jwt_token = None
async def receive():
return {"type": "http.request", "body": body}
# Set context variables for search tool
if filters:
set_search_filters(filters)
set_search_limit(limit)
set_score_threshold(score_threshold)
set_auth_context(user_id, jwt_token)
internal_request = Request(request.scope, receive)
internal_request.state = request.state # Copy state for auth
if stream:
# Streaming response
raw_stream = await chat_service.chat(
prompt=message,
user_id=user_id,
jwt_token=jwt_token,
previous_response_id=chat_id,
stream=True,
)
chat_id_container = {}
return StreamingResponse(
_transform_stream_to_sse(raw_stream, chat_id_container),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
else:
# Non-streaming response
result = await chat_service.chat(
prompt=message,
user_id=user_id,
jwt_token=jwt_token,
previous_response_id=chat_id,
stream=False,
)
# Transform response to public API format
# Internal format: {"response": "...", "response_id": "..."}
response_data = {
"response": result.get("response", ""),
"chat_id": result.get("response_id"),
"sources": result.get("sources", []),
}
return JSONResponse(response_data)
# Call internal Langflow endpoint
return await langflow_endpoint(internal_request, chat_service, session_manager)
async def chat_list_endpoint(request: Request, chat_service, session_manager):
@ -240,8 +87,8 @@ async def chat_list_endpoint(request: Request, chat_service, session_manager):
user_id = user.user_id
try:
# Get chat history
history = await chat_service.get_chat_history(user_id)
# Get Langflow chat history (since v1 routes through Langflow)
history = await chat_service.get_langflow_history(user_id)
# Transform to public API format
conversations = []
@ -293,8 +140,8 @@ async def chat_get_endpoint(request: Request, chat_service, session_manager):
)
try:
# Get chat history and find the specific conversation
history = await chat_service.get_chat_history(user_id)
# Get Langflow chat history and find the specific conversation
history = await chat_service.get_langflow_history(user_id)
conversation = None
for conv in history.get("conversations", []):