contexts endpoints and chat support via context vars

This commit is contained in:
estevez.sebastian@gmail.com 2025-08-12 15:01:46 -04:00
parent e87efaaeeb
commit 963c02e6a0
4 changed files with 337 additions and 1 deletions

View file

@ -7,6 +7,9 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
prompt = data.get("prompt", "")
previous_response_id = data.get("previous_response_id")
stream = data.get("stream", False)
filters = data.get("filters")
limit = data.get("limit", 10)
score_threshold = data.get("scoreThreshold", 0)
user = request.state.user
user_id = user.user_id
@ -16,6 +19,15 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
if not prompt:
return JSONResponse({"error": "Prompt is required"}, status_code=400)
# Set context variables for search tool (similar to search endpoint)
if filters:
from auth_context import set_search_filters
set_search_filters(filters)
from auth_context import set_search_limit, set_score_threshold
set_search_limit(limit)
set_score_threshold(score_threshold)
if stream:
return StreamingResponse(
@ -38,6 +50,12 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
prompt = data.get("prompt", "")
previous_response_id = data.get("previous_response_id")
stream = data.get("stream", False)
filters = data.get("filters")
limit = data.get("limit", 10)
score_threshold = data.get("scoreThreshold", 0)
user = request.state.user
user_id = user.user_id
# Get JWT token from request cookie
jwt_token = request.cookies.get("auth_token")
@ -45,6 +63,15 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
if not prompt:
return JSONResponse({"error": "Prompt is required"}, status_code=400)
# Set context variables for search tool (similar to chat endpoint)
if filters:
from auth_context import set_search_filters
set_search_filters(filters)
from auth_context import set_search_limit, set_score_threshold
set_search_limit(limit)
set_score_threshold(score_threshold)
try:
if stream:
return StreamingResponse(

114
src/api/contexts.py Normal file
View file

@ -0,0 +1,114 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
import uuid
from datetime import datetime
async def create_context(request: Request, contexts_service, session_manager):
"""Create a new search context"""
payload = await request.json()
name = payload.get("name")
if not name:
return JSONResponse({"error": "Context name is required"}, status_code=400)
description = payload.get("description", "")
query_data = payload.get("queryData")
if not query_data:
return JSONResponse({"error": "Query data is required"}, status_code=400)
user = request.state.user
jwt_token = request.cookies.get("auth_token")
# Create context document
context_id = str(uuid.uuid4())
context_doc = {
"id": context_id,
"name": name,
"description": description,
"query_data": query_data, # Store the full search query JSON
"owner": user.user_id,
"allowed_users": payload.get("allowedUsers", []), # ACL field for future use
"allowed_groups": payload.get("allowedGroups", []), # ACL field for future use
"created_at": datetime.utcnow().isoformat(),
"updated_at": datetime.utcnow().isoformat()
}
result = await contexts_service.create_context(context_doc, user_id=user.user_id, jwt_token=jwt_token)
return JSONResponse(result)
async def search_contexts(request: Request, contexts_service, session_manager):
"""Search for contexts by name, description, or query content"""
payload = await request.json()
query = payload.get("query", "")
limit = payload.get("limit", 20)
user = request.state.user
jwt_token = request.cookies.get("auth_token")
result = await contexts_service.search_contexts(query, user_id=user.user_id, jwt_token=jwt_token, limit=limit)
return JSONResponse(result)
async def get_context(request: Request, contexts_service, session_manager):
"""Get a specific context by ID"""
context_id = request.path_params.get("context_id")
if not context_id:
return JSONResponse({"error": "Context ID is required"}, status_code=400)
user = request.state.user
jwt_token = request.cookies.get("auth_token")
result = await contexts_service.get_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
return JSONResponse(result)
async def update_context(request: Request, contexts_service, session_manager):
"""Update an existing context by delete + recreate (due to DLS limitations)"""
context_id = request.path_params.get("context_id")
if not context_id:
return JSONResponse({"error": "Context ID is required"}, status_code=400)
payload = await request.json()
user = request.state.user
jwt_token = request.cookies.get("auth_token")
# First, get the existing context
existing_result = await contexts_service.get_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
if not existing_result.get("success"):
return JSONResponse({"error": "Context not found or access denied"}, status_code=404)
existing_context = existing_result["context"]
# Delete the existing context
delete_result = await contexts_service.delete_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
if not delete_result.get("success"):
return JSONResponse({"error": "Failed to delete existing context"}, status_code=500)
# Create updated context document with same ID
updated_context = {
"id": context_id,
"name": payload.get("name", existing_context["name"]),
"description": payload.get("description", existing_context["description"]),
"query_data": payload.get("queryData", existing_context["query_data"]),
"owner": existing_context["owner"],
"allowed_users": payload.get("allowedUsers", existing_context.get("allowed_users", [])),
"allowed_groups": payload.get("allowedGroups", existing_context.get("allowed_groups", [])),
"created_at": existing_context["created_at"], # Preserve original creation time
"updated_at": datetime.utcnow().isoformat()
}
# Recreate the context
result = await contexts_service.create_context(updated_context, user_id=user.user_id, jwt_token=jwt_token)
return JSONResponse(result)
async def delete_context(request: Request, contexts_service, session_manager):
"""Delete a context"""
context_id = request.path_params.get("context_id")
if not context_id:
return JSONResponse({"error": "Context ID is required"}, status_code=400)
user = request.state.user
jwt_token = request.cookies.get("auth_token")
result = await contexts_service.delete_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
return JSONResponse(result)

View file

@ -23,6 +23,7 @@ from services.search_service import SearchService
from services.task_service import TaskService
from services.auth_service import AuthService
from services.chat_service import ChatService
from services.contexts_service import ContextsService
# Existing services
from connectors.service import ConnectorService
@ -30,7 +31,7 @@ from session_manager import SessionManager
from auth_middleware import require_auth, optional_auth
# API endpoints
from api import upload, search, chat, auth, connectors, tasks, oidc
from api import upload, search, chat, auth, connectors, tasks, oidc, contexts
print("CUDA available:", torch.cuda.is_available())
print("CUDA version PyTorch was built with:", torch.version.cuda)
@ -56,11 +57,36 @@ async def init_index():
"""Initialize OpenSearch index and security roles"""
await wait_for_opensearch()
# Create documents index
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
print(f"Created index '{INDEX_NAME}'")
else:
print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
# Create contexts index
contexts_index_name = "search_contexts"
contexts_index_body = {
"mappings": {
"properties": {
"id": {"type": "keyword"},
"name": {"type": "text", "analyzer": "standard"},
"description": {"type": "text", "analyzer": "standard"},
"query_data": {"type": "text"}, # Store as text for searching
"owner": {"type": "keyword"},
"allowed_users": {"type": "keyword"},
"allowed_groups": {"type": "keyword"},
"created_at": {"type": "date"},
"updated_at": {"type": "date"}
}
}
}
if not await clients.opensearch.indices.exists(index=contexts_index_name):
await clients.opensearch.indices.create(index=contexts_index_name, body=contexts_index_body)
print(f"Created index '{contexts_index_name}'")
else:
print(f"Index '{contexts_index_name}' already exists, skipping creation.")
async def init_index_when_ready():
"""Initialize OpenSearch index when it becomes available"""
@ -85,6 +111,7 @@ def initialize_services():
search_service = SearchService(session_manager)
task_service = TaskService(document_service, process_pool)
chat_service = ChatService()
contexts_service = ContextsService(session_manager)
# Set process pool for document service
document_service.process_pool = process_pool
@ -109,6 +136,7 @@ def initialize_services():
'chat_service': chat_service,
'auth_service': auth_service,
'connector_service': connector_service,
'contexts_service': contexts_service,
'session_manager': session_manager
}
@ -170,6 +198,42 @@ def create_app():
session_manager=services['session_manager'])
), methods=["POST"]),
# Contexts endpoints
Route("/contexts",
require_auth(services['session_manager'])(
partial(contexts.create_context,
contexts_service=services['contexts_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/contexts/search",
require_auth(services['session_manager'])(
partial(contexts.search_contexts,
contexts_service=services['contexts_service'],
session_manager=services['session_manager'])
), methods=["POST"]),
Route("/contexts/{context_id}",
require_auth(services['session_manager'])(
partial(contexts.get_context,
contexts_service=services['contexts_service'],
session_manager=services['session_manager'])
), methods=["GET"]),
Route("/contexts/{context_id}",
require_auth(services['session_manager'])(
partial(contexts.update_context,
contexts_service=services['contexts_service'],
session_manager=services['session_manager'])
), methods=["PUT"]),
Route("/contexts/{context_id}",
require_auth(services['session_manager'])(
partial(contexts.delete_context,
contexts_service=services['contexts_service'],
session_manager=services['session_manager'])
), methods=["DELETE"]),
# Chat endpoints
Route("/chat",
require_auth(services['session_manager'])(

View file

@ -0,0 +1,131 @@
from typing import Any, Dict, Optional
CONTEXTS_INDEX_NAME = "search_contexts"
class ContextsService:
def __init__(self, session_manager=None):
self.session_manager = session_manager
async def create_context(self, context_doc: Dict[str, Any], user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
"""Create a new search context"""
try:
# Get user's OpenSearch client with JWT for OIDC auth
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
# Index the context document
result = await opensearch_client.index(
index=CONTEXTS_INDEX_NAME,
id=context_doc["id"],
body=context_doc
)
if result.get("result") == "created":
return {"success": True, "id": context_doc["id"], "context": context_doc}
else:
return {"success": False, "error": "Failed to create context"}
except Exception as e:
return {"success": False, "error": str(e)}
async def search_contexts(self, query: str, user_id: str = None, jwt_token: str = None, limit: int = 20) -> Dict[str, Any]:
"""Search for contexts by name, description, or query content"""
try:
# Get user's OpenSearch client with JWT for OIDC auth
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
if query.strip():
# Search across name, description, and query_data fields
search_body = {
"query": {
"multi_match": {
"query": query,
"fields": ["name^3", "description^2", "query_data"],
"type": "best_fields",
"fuzziness": "AUTO"
}
},
"sort": [
{"_score": {"order": "desc"}},
{"updated_at": {"order": "desc"}}
],
"_source": ["id", "name", "description", "query_data", "owner", "created_at", "updated_at"],
"size": limit
}
else:
# No query - return all contexts sorted by most recent
search_body = {
"query": {"match_all": {}},
"sort": [{"updated_at": {"order": "desc"}}],
"_source": ["id", "name", "description", "query_data", "owner", "created_at", "updated_at"],
"size": limit
}
result = await opensearch_client.search(index=CONTEXTS_INDEX_NAME, body=search_body)
# Transform results
contexts = []
for hit in result["hits"]["hits"]:
context = hit["_source"]
context["score"] = hit.get("_score")
contexts.append(context)
return {"success": True, "contexts": contexts}
except Exception as e:
return {"success": False, "error": str(e), "contexts": []}
async def get_context(self, context_id: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
"""Get a specific context by ID"""
try:
# Get user's OpenSearch client with JWT for OIDC auth
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
result = await opensearch_client.get(index=CONTEXTS_INDEX_NAME, id=context_id)
if result.get("found"):
context = result["_source"]
return {"success": True, "context": context}
else:
return {"success": False, "error": "Context not found"}
except Exception as e:
return {"success": False, "error": str(e)}
async def update_context(self, context_id: str, updates: Dict[str, Any], user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
"""Update an existing context"""
try:
# Get user's OpenSearch client with JWT for OIDC auth
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
# Update the document
result = await opensearch_client.update(
index=CONTEXTS_INDEX_NAME,
id=context_id,
body={"doc": updates}
)
if result.get("result") in ["updated", "noop"]:
# Get the updated document
updated_doc = await opensearch_client.get(index=CONTEXTS_INDEX_NAME, id=context_id)
return {"success": True, "context": updated_doc["_source"]}
else:
return {"success": False, "error": "Failed to update context"}
except Exception as e:
return {"success": False, "error": str(e)}
async def delete_context(self, context_id: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
"""Delete a context"""
try:
# Get user's OpenSearch client with JWT for OIDC auth
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
result = await opensearch_client.delete(index=CONTEXTS_INDEX_NAME, id=context_id)
if result.get("result") == "deleted":
return {"success": True, "message": "Context deleted successfully"}
else:
return {"success": False, "error": "Failed to delete context"}
except Exception as e:
return {"success": False, "error": str(e)}