sdk chat endpoint fix

This commit is contained in:
phact 2025-12-22 22:00:42 -05:00
parent 900e5f9795
commit 850c504c55

View file

@ -2,144 +2,36 @@
Public API v1 Chat endpoint. Public API v1 Chat endpoint.
Provides chat functionality with streaming support and conversation history. Provides chat functionality with streaming support and conversation history.
Uses API key authentication. Uses API key authentication. Routes through Langflow endpoint.
""" """
import json import json
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse from starlette.responses import JSONResponse
from utils.logging_config import get_logger from utils.logging_config import get_logger
from auth_context import set_search_filters, set_search_limit, set_score_threshold, set_auth_context from api.chat import langflow_endpoint
logger = get_logger(__name__) logger = get_logger(__name__)
def _sse_event(name: str, payload: dict) -> str:
    """Serialize *payload* as one Server-Sent Event frame named *name*."""
    return f"event: {name}\ndata: {json.dumps(payload)}\n\n"


def _parse_chunk_payloads(chunk_str: str):
    """Decode one stripped stream chunk into a list of JSON payloads.

    Returns:
        A list of decoded JSON values, or ``None`` when the chunk is not JSON
        and does not start with ``{`` (the caller treats it as raw text).
    """
    try:
        return [json.loads(chunk_str)]
    except json.JSONDecodeError:
        pass

    if not chunk_str.startswith("{"):
        return None

    # Several JSON objects can arrive coalesced in one chunk, one per line.
    # Parsing the chunk as a whole fails in that case, so decode line by line
    # instead of discarding everything. Lines that are not JSON are dropped.
    payloads = []
    for line in chunk_str.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            payloads.append(json.loads(line))
        except json.JSONDecodeError:
            continue
    return payloads


def _extract_delta_text(chunk_data: dict):
    """Return the text delta carried by a decoded chunk, or ``None``.

    A truthy ``output_text`` takes precedence over ``delta``. ``delta`` may be
    a plain string or a dict holding the text under ``content`` or ``text``.
    """
    delta_text = None
    delta = chunk_data.get("delta")
    if isinstance(delta, dict):
        delta_text = delta.get("content") or delta.get("text") or ""
    elif isinstance(delta, str):
        delta_text = delta
    if chunk_data.get("output_text"):
        delta_text = chunk_data["output_text"]
    return delta_text


def _extract_sources(chunk_data: dict) -> list:
    """Return the source records found in a chunk's tool-call results."""
    item = chunk_data.get("item")
    if not isinstance(item, dict):
        return []
    if item.get("type") not in ("retrieval_call", "tool_call", "function_call"):
        return []
    return [
        {
            "filename": result.get("filename", result.get("title", "Unknown")),
            "text": result.get("text", result.get("content", "")),
            "score": result.get("score", 0),
            "page": result.get("page"),
        }
        for result in item.get("results") or []
        if isinstance(result, dict)
    ]


async def _transform_stream_to_sse(raw_stream, chat_id_container: dict):
    """Transform the raw internal streaming format to clean SSE events.

    Yields SSE events in the format:

        event: content
        data: {"type": "content", "delta": "..."}

        event: sources
        data: {"type": "sources", "sources": [...]}

        event: done
        data: {"type": "done", "chat_id": "..."}

    Args:
        raw_stream: Async iterable of chunks (``bytes`` or anything accepted
            by ``str()``). Each chunk is either JSON or raw text.
        chat_id_container: Dict that receives the final id under ``"chat_id"``
            (``None`` when no chunk carried an ``id`` / ``response_id``).

    Chunks that raise while being processed are logged and skipped so a
    single bad chunk does not abort the stream.
    """
    sources = []
    chat_id = None

    async for chunk in raw_stream:
        try:
            if isinstance(chunk, bytes):
                chunk_str = chunk.decode("utf-8").strip()
            else:
                chunk_str = str(chunk).strip()
            if not chunk_str:
                continue

            payloads = _parse_chunk_payloads(chunk_str)
            if payloads is None:
                # Not JSON: forward the chunk as raw text.
                yield _sse_event("content", {"type": "content", "delta": chunk_str})
                continue

            for chunk_data in payloads:
                if not isinstance(chunk_data, dict):
                    # Only JSON objects carry the fields read below.
                    continue

                delta_text = _extract_delta_text(chunk_data)
                # Only textual deltas are forwarded to the client.
                if delta_text and isinstance(delta_text, str):
                    yield _sse_event(
                        "content", {"type": "content", "delta": delta_text}
                    )

                # "id" wins over "response_id"; keep the previous id otherwise.
                chat_id = (
                    chunk_data.get("id") or chunk_data.get("response_id") or chat_id
                )
                sources.extend(_extract_sources(chunk_data))
        except Exception as e:
            logger.warning("Error processing stream chunk", error=str(e))
            continue

    if sources:
        yield _sse_event("sources", {"type": "sources", "sources": sources})

    # Store the id before the final event so it is available even when the
    # consumer stops iterating as soon as it has received "done".
    chat_id_container["chat_id"] = chat_id
    yield _sse_event("done", {"type": "done", "chat_id": chat_id})


def _transform_v1_request_to_internal(data: dict) -> dict:
    """Transform v1 API request format to internal Langflow format.

    Field renames: ``message`` -> ``prompt``, ``chat_id`` ->
    ``previous_response_id``, ``score_threshold`` -> ``scoreThreshold``.
    ``stream``, ``filters``, ``limit`` and ``filter_id`` keep their names.

    Args:
        data: Parsed JSON body of the v1 request.

    Returns:
        A new dict using the internal field names. Absent fields default to
        ``""`` (prompt), ``False`` (stream), ``10`` (limit), ``0``
        (scoreThreshold) and ``None`` for the rest.
    """
    return {
        "prompt": data.get("message", ""),  # v1 uses "message", internal uses "prompt"
        "previous_response_id": data.get("chat_id"),  # v1 uses "chat_id"
        "stream": data.get("stream", False),
        "filters": data.get("filters"),
        "limit": data.get("limit", 10),
        "scoreThreshold": data.get("score_threshold", 0),  # v1 uses snake_case
        "filter_id": data.get("filter_id"),
    }
async def chat_create_endpoint(request: Request, chat_service, session_manager): async def chat_create_endpoint(request: Request, chat_service, session_manager):
""" """
Send a chat message. Send a chat message. Routes to internal Langflow endpoint.
POST /v1/chat POST /v1/chat - see internal /langflow endpoint for full documentation.
Transforms v1 format (message, chat_id, score_threshold) to internal format.
Request body:
{
"message": "What is RAG?",
"stream": false, // optional, default false
"chat_id": "...", // optional, to continue conversation
"filters": {...}, // optional
"limit": 10, // optional
"score_threshold": 0.5 // optional
}
Non-streaming response:
{
"response": "RAG stands for...",
"chat_id": "chat_xyz789",
"sources": [...]
}
Streaming response (SSE):
event: content
data: {"type": "content", "delta": "RAG stands for"}
event: sources
data: {"type": "sources", "sources": [...]}
event: done
data: {"type": "done", "chat_id": "chat_xyz789"}
""" """
try: try:
data = await request.json() data = await request.json()
@ -156,65 +48,20 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
status_code=400, status_code=400,
) )
stream = data.get("stream", False) # Transform v1 request to internal format
chat_id = data.get("chat_id") # For conversation continuation internal_data = _transform_v1_request_to_internal(data)
filters = data.get("filters")
limit = data.get("limit", 10)
score_threshold = data.get("score_threshold", 0)
user = request.state.user # Create a new request with transformed body for the internal endpoint
user_id = user.user_id body = json.dumps(internal_data).encode()
# Note: API key auth doesn't have JWT, so we pass None async def receive():
jwt_token = None return {"type": "http.request", "body": body}
# Set context variables for search tool internal_request = Request(request.scope, receive)
if filters: internal_request.state = request.state # Copy state for auth
set_search_filters(filters)
set_search_limit(limit)
set_score_threshold(score_threshold)
set_auth_context(user_id, jwt_token)
if stream: # Call internal Langflow endpoint
# Streaming response return await langflow_endpoint(internal_request, chat_service, session_manager)
raw_stream = await chat_service.chat(
prompt=message,
user_id=user_id,
jwt_token=jwt_token,
previous_response_id=chat_id,
stream=True,
)
chat_id_container = {}
return StreamingResponse(
_transform_stream_to_sse(raw_stream, chat_id_container),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
else:
# Non-streaming response
result = await chat_service.chat(
prompt=message,
user_id=user_id,
jwt_token=jwt_token,
previous_response_id=chat_id,
stream=False,
)
# Transform response to public API format
# Internal format: {"response": "...", "response_id": "..."}
response_data = {
"response": result.get("response", ""),
"chat_id": result.get("response_id"),
"sources": result.get("sources", []),
}
return JSONResponse(response_data)
async def chat_list_endpoint(request: Request, chat_service, session_manager): async def chat_list_endpoint(request: Request, chat_service, session_manager):
@ -240,8 +87,8 @@ async def chat_list_endpoint(request: Request, chat_service, session_manager):
user_id = user.user_id user_id = user.user_id
try: try:
# Get chat history # Get Langflow chat history (since v1 routes through Langflow)
history = await chat_service.get_chat_history(user_id) history = await chat_service.get_langflow_history(user_id)
# Transform to public API format # Transform to public API format
conversations = [] conversations = []
@ -293,8 +140,8 @@ async def chat_get_endpoint(request: Request, chat_service, session_manager):
) )
try: try:
# Get chat history and find the specific conversation # Get Langflow chat history and find the specific conversation
history = await chat_service.get_chat_history(user_id) history = await chat_service.get_langflow_history(user_id)
conversation = None conversation = None
for conv in history.get("conversations", []): for conv in history.get("conversations", []):