373 lines
12 KiB
Python
373 lines
12 KiB
Python
"""
Public API v1 Chat endpoint.

Provides chat functionality with streaming support and conversation history.
Uses API key authentication.
"""
|
|
import json
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse, StreamingResponse
|
|
from utils.logging_config import get_logger
|
|
from auth_context import set_search_filters, set_search_limit, set_score_threshold, set_auth_context
|
|
|
|
# Module-wide logger, obtained from the project's get_logger helper and named after this module.
logger = get_logger(__name__)
|
|
|
|
|
|
def _format_sse(event_name: str, payload: dict) -> str:
    """Serialize *payload* as a single SSE frame carrying the given event name."""
    return f"event: {event_name}\ndata: {json.dumps(payload)}\n\n"


def _extract_delta_text(chunk_data: dict):
    """Return the text carried by a parsed chunk, or None when it has none.

    A truthy ``output_text`` takes precedence over ``delta``. ``delta`` may be
    either a plain string or a dict holding ``content`` / ``text``.
    """
    delta_text = None
    delta = chunk_data.get("delta")
    if isinstance(delta, dict):
        delta_text = delta.get("content") or delta.get("text") or ""
    elif isinstance(delta, str):
        delta_text = delta

    output_text = chunk_data.get("output_text")
    if output_text:
        delta_text = output_text
    return delta_text


def _extract_sources(chunk_data: dict) -> list:
    """Return the source records found in a chunk's tool-call results, if any."""
    item = chunk_data.get("item")
    if not isinstance(item, dict):
        return []
    if item.get("type") not in ("retrieval_call", "tool_call", "function_call"):
        return []

    results = item.get("results") or []
    return [
        {
            "filename": result.get("filename", result.get("title", "Unknown")),
            "text": result.get("text", result.get("content", "")),
            "score": result.get("score", 0),
            "page": result.get("page"),
        }
        for result in results
        if isinstance(result, dict)
    ]


async def _transform_stream_to_sse(raw_stream, chat_id_container: dict):
    """
    Transform the raw internal streaming format to clean SSE events.

    Args:
        raw_stream: Async iterable of chunks. ``bytes`` chunks are decoded as
            UTF-8; anything else is converted with ``str()``. Each chunk is
            expected to be one JSON object; a chunk that is not valid JSON and
            does not start with ``{`` is forwarded verbatim as content.
        chat_id_container: Mutable dict that receives the last seen chat id
            under the ``"chat_id"`` key (``None`` if no id was seen).

    Yields SSE events in the format:
        event: content
        data: {"type": "content", "delta": "..."}

        event: sources
        data: {"type": "sources", "sources": [...]}

        event: done
        data: {"type": "done", "chat_id": "..."}

    The ``sources`` event is emitted only when at least one source was
    collected; ``done`` is always the last event.
    """
    sources = []
    chat_id = None

    async for chunk in raw_stream:
        try:
            # Decode the chunk
            if isinstance(chunk, bytes):
                chunk_str = chunk.decode("utf-8").strip()
            else:
                chunk_str = str(chunk).strip()

            if not chunk_str:
                continue

            try:
                chunk_data = json.loads(chunk_str)
            except json.JSONDecodeError:
                # Not JSON: forward it as raw text, unless it looks like a
                # broken JSON object, which is dropped rather than shown.
                if not chunk_str.startswith("{"):
                    yield _format_sse("content", {"type": "content", "delta": chunk_str})
                continue

            # Valid JSON that is not an object (list, number, string, ...)
            # carries none of the fields read below.
            if not isinstance(chunk_data, dict):
                continue

            # Only string deltas are forwarded; an unexpected payload type must
            # not prevent the id and sources of the same chunk from being read.
            delta_text = _extract_delta_text(chunk_data)
            if isinstance(delta_text, str) and delta_text:
                yield _format_sse("content", {"type": "content", "delta": delta_text})

            # "id" wins over "response_id"; later chunks overwrite earlier ids.
            seen_id = chunk_data.get("id") or chunk_data.get("response_id")
            if seen_id:
                chat_id = seen_id

            sources.extend(_extract_sources(chunk_data))

        except Exception as e:
            # Best effort: one bad chunk must not abort the whole stream.
            logger.warning("Error processing stream chunk", error=str(e))

    # Publish the id before the trailing events are yielded: a consumer that
    # stops iterating once it has seen "done" never resumes this generator, so
    # anything placed after the last yield would not run.
    chat_id_container["chat_id"] = chat_id

    if sources:
        yield _format_sse("sources", {"type": "sources", "sources": sources})

    yield _format_sse("done", {"type": "done", "chat_id": chat_id})
|
|
|
|
|
|
async def chat_create_endpoint(request: Request, chat_service, session_manager):
    """
    Send a chat message.

    POST /v1/chat

    Request body:
        {
            "message": "What is RAG?",
            "stream": false,           // optional, default false
            "chat_id": "...",          // optional, to continue conversation
            "filters": {...},          // optional
            "limit": 10,               // optional
            "score_threshold": 0.5     // optional
        }

    Non-streaming response:
        {
            "response": "RAG stands for...",
            "chat_id": "chat_xyz789",
            "sources": [...]
        }

    Streaming response (SSE):
        event: content
        data: {"type": "content", "delta": "RAG stands for"}

        event: sources
        data: {"type": "sources", "sources": [...]}

        event: done
        data: {"type": "done", "chat_id": "chat_xyz789"}

    Returns a 400 JSON error when the body is not valid JSON, is not a JSON
    object, or lacks a non-blank string ``message``. ``session_manager`` is
    accepted for signature parity with the other endpoints and is not used.
    """
    try:
        data = await request.json()
    except Exception:
        return JSONResponse(
            {"error": "Invalid JSON in request body"},
            status_code=400,
        )

    # Valid JSON is not necessarily an object (e.g. `[]` or `"hi"`); reject it
    # here instead of failing on `.get` below with an unhandled error.
    if not isinstance(data, dict):
        return JSONResponse(
            {"error": "Request body must be a JSON object"},
            status_code=400,
        )

    # A non-string message (null, number, list, ...) is treated like a missing one.
    raw_message = data.get("message")
    message = raw_message.strip() if isinstance(raw_message, str) else ""
    if not message:
        return JSONResponse(
            {"error": "Message is required"},
            status_code=400,
        )

    stream = data.get("stream", False)
    chat_id = data.get("chat_id")  # For conversation continuation
    filters = data.get("filters")
    limit = data.get("limit", 10)
    score_threshold = data.get("score_threshold", 0)

    user = request.state.user
    user_id = user.user_id

    # Note: API key auth doesn't have JWT, so we pass None
    jwt_token = None

    # Set context variables for search tool. Filters are only set when
    # provided; limit, threshold and auth context are always set.
    if filters:
        set_search_filters(filters)
    set_search_limit(limit)
    set_score_threshold(score_threshold)
    set_auth_context(user_id, jwt_token)

    if stream:
        # Streaming response
        raw_stream = await chat_service.chat(
            prompt=message,
            user_id=user_id,
            jwt_token=jwt_token,
            previous_response_id=chat_id,
            stream=True,
        )

        # Filled in by the SSE transformer; not read back in this function.
        chat_id_container = {}

        return StreamingResponse(
            _transform_stream_to_sse(raw_stream, chat_id_container),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # Non-streaming response
    result = await chat_service.chat(
        prompt=message,
        user_id=user_id,
        jwt_token=jwt_token,
        previous_response_id=chat_id,
        stream=False,
    )

    # Transform response to public API format
    # Internal format: {"response": "...", "response_id": "..."}
    response_data = {
        "response": result.get("response", ""),
        "chat_id": result.get("response_id"),
        "sources": result.get("sources", []),
    }

    return JSONResponse(response_data)
|
|
|
|
|
|
async def chat_list_endpoint(request: Request, chat_service, session_manager):
    """
    Return every conversation belonging to the authenticated user.

    GET /v1/chat

    Response:
        {
            "conversations": [
                {
                    "chat_id": "...",
                    "title": "What is RAG?",
                    "created_at": "...",
                    "last_activity": "...",
                    "message_count": 5
                }
            ]
        }

    Each entry is a renamed view of the internal history record
    (``response_id`` -> ``chat_id``, ``total_messages`` -> ``message_count``).
    Any failure while loading or transforming the history is logged and
    reported as a 500 JSON error.
    """
    user_id = request.state.user.user_id

    try:
        history = await chat_service.get_chat_history(user_id)

        # Map internal records to the public API shape.
        conversations = [
            {
                "chat_id": entry.get("response_id"),
                "title": entry.get("title", ""),
                "created_at": entry.get("created_at"),
                "last_activity": entry.get("last_activity"),
                "message_count": entry.get("total_messages", 0),
            }
            for entry in history.get("conversations", [])
        ]

        return JSONResponse({"conversations": conversations})

    except Exception as e:
        detail = str(e)
        logger.error("Failed to list conversations", error=detail, user_id=user_id)
        return JSONResponse(
            {"error": f"Failed to list conversations: {detail}"},
            status_code=500,
        )
|
|
|
|
|
|
async def chat_get_endpoint(request: Request, chat_service, session_manager):
    """
    Return one conversation, including its full message history.

    GET /v1/chat/{chat_id}

    Response:
        {
            "chat_id": "...",
            "title": "What is RAG?",
            "created_at": "...",
            "last_activity": "...",
            "messages": [
                {"role": "user", "content": "What is RAG?", "timestamp": "..."},
                {"role": "assistant", "content": "RAG stands for...", "timestamp": "..."}
            ]
        }

    The user's whole chat history is loaded and scanned for the record whose
    ``response_id`` equals the ``chat_id`` path parameter. Responds with 400
    when the path parameter is missing, 404 when no record matches, and 500
    (after logging) on any other failure.
    """
    user_id = request.state.user.user_id
    chat_id = request.path_params.get("chat_id")

    if not chat_id:
        return JSONResponse(
            {"error": "Chat ID is required"},
            status_code=400,
        )

    try:
        history = await chat_service.get_chat_history(user_id)

        # First record with a matching id, or None when there is no match.
        conversation = next(
            (
                candidate
                for candidate in history.get("conversations", [])
                if candidate.get("response_id") == chat_id
            ),
            None,
        )

        if not conversation:
            return JSONResponse(
                {"error": "Conversation not found"},
                status_code=404,
            )

        # Expose only the public fields of each message.
        messages = [
            {
                "role": entry.get("role"),
                "content": entry.get("content"),
                "timestamp": entry.get("timestamp"),
            }
            for entry in conversation.get("messages", [])
        ]

        return JSONResponse(
            {
                "chat_id": conversation.get("response_id"),
                "title": conversation.get("title", ""),
                "created_at": conversation.get("created_at"),
                "last_activity": conversation.get("last_activity"),
                "messages": messages,
            }
        )

    except Exception as e:
        detail = str(e)
        logger.error("Failed to get conversation", error=detail, user_id=user_id, chat_id=chat_id)
        return JSONResponse(
            {"error": f"Failed to get conversation: {detail}"},
            status_code=500,
        )
|
|
|
|
|
|
async def chat_delete_endpoint(request: Request, chat_service, session_manager):
    """
    Remove a conversation.

    DELETE /v1/chat/{chat_id}

    Response:
        {"success": true}

    Responds with 400 when the ``chat_id`` path parameter is missing. When the
    service reports no success, or raises, a 500 JSON error is returned; the
    service-provided ``error`` text is used if present.
    """
    user_id = request.state.user.user_id
    chat_id = request.path_params.get("chat_id")

    if not chat_id:
        return JSONResponse(
            {"error": "Chat ID is required"},
            status_code=400,
        )

    try:
        outcome = await chat_service.delete_session(user_id, chat_id)

        # Failure reported by the service itself (no exception raised).
        if not outcome.get("success"):
            return JSONResponse(
                {"error": outcome.get("error", "Failed to delete conversation")},
                status_code=500,
            )

        return JSONResponse({"success": True})

    except Exception as e:
        detail = str(e)
        logger.error("Failed to delete conversation", error=detail, user_id=user_id, chat_id=chat_id)
        return JSONResponse(
            {"error": f"Failed to delete conversation: {detail}"},
            status_code=500,
        )
|