From 963c02e6a02ba67c18ebb227fc6a9711fa359dce Mon Sep 17 00:00:00 2001 From: "estevez.sebastian@gmail.com" Date: Tue, 12 Aug 2025 15:01:46 -0400 Subject: [PATCH] contexts endpoints and chat support via context vars --- src/api/chat.py | 27 +++++++ src/api/contexts.py | 114 +++++++++++++++++++++++++++ src/main.py | 66 +++++++++++++++- src/services/contexts_service.py | 131 +++++++++++++++++++++++++++++++ 4 files changed, 337 insertions(+), 1 deletion(-) create mode 100644 src/api/contexts.py create mode 100644 src/services/contexts_service.py diff --git a/src/api/chat.py b/src/api/chat.py index 8f6e2e59..3be0c3a8 100644 --- a/src/api/chat.py +++ b/src/api/chat.py @@ -7,6 +7,9 @@ async def chat_endpoint(request: Request, chat_service, session_manager): prompt = data.get("prompt", "") previous_response_id = data.get("previous_response_id") stream = data.get("stream", False) + filters = data.get("filters") + limit = data.get("limit", 10) + score_threshold = data.get("scoreThreshold", 0) user = request.state.user user_id = user.user_id @@ -16,6 +19,15 @@ async def chat_endpoint(request: Request, chat_service, session_manager): if not prompt: return JSONResponse({"error": "Prompt is required"}, status_code=400) + + # Set context variables for search tool (similar to search endpoint) + if filters: + from auth_context import set_search_filters + set_search_filters(filters) + + from auth_context import set_search_limit, set_score_threshold + set_search_limit(limit) + set_score_threshold(score_threshold) if stream: return StreamingResponse( @@ -38,6 +50,12 @@ async def langflow_endpoint(request: Request, chat_service, session_manager): prompt = data.get("prompt", "") previous_response_id = data.get("previous_response_id") stream = data.get("stream", False) + filters = data.get("filters") + limit = data.get("limit", 10) + score_threshold = data.get("scoreThreshold", 0) + + user = request.state.user + user_id = user.user_id # Get JWT token from request cookie jwt_token = request.cookies.get("auth_token") @@ -45,6 +63,15 @@ async def langflow_endpoint(request: Request, chat_service, session_manager): if not prompt: return JSONResponse({"error": "Prompt is required"}, status_code=400) + # Set context variables for search tool (similar to chat endpoint) + if filters: + from auth_context import set_search_filters + set_search_filters(filters) + + from auth_context import set_search_limit, set_score_threshold + set_search_limit(limit) + set_score_threshold(score_threshold) + try: if stream: return StreamingResponse( diff --git a/src/api/contexts.py b/src/api/contexts.py new file mode 100644 index 00000000..abd4ae43 --- /dev/null +++ b/src/api/contexts.py @@ -0,0 +1,114 @@ +from starlette.requests import Request +from starlette.responses import JSONResponse +import uuid +from datetime import datetime + +async def create_context(request: Request, contexts_service, session_manager): + """Create a new search context""" + payload = await request.json() + + name = payload.get("name") + if not name: + return JSONResponse({"error": "Context name is required"}, status_code=400) + + description = payload.get("description", "") + query_data = payload.get("queryData") + if not query_data: + return JSONResponse({"error": "Query data is required"}, status_code=400) + + user = request.state.user + jwt_token = request.cookies.get("auth_token") + + # Create context document + context_id = str(uuid.uuid4()) + context_doc = { + "id": context_id, + "name": name, + "description": description, + "query_data": query_data, # Store the full search query JSON + "owner": user.user_id, + "allowed_users": payload.get("allowedUsers", []), # ACL field for future use + "allowed_groups": payload.get("allowedGroups", []), # ACL field for future use + "created_at": datetime.utcnow().isoformat(), + "updated_at": datetime.utcnow().isoformat() + } + + result = await contexts_service.create_context(context_doc, user_id=user.user_id, jwt_token=jwt_token) + return JSONResponse(result) + +async def search_contexts(request: Request, contexts_service, session_manager): + """Search for contexts by name, description, or query content""" + payload = await request.json() + + query = payload.get("query", "") + limit = payload.get("limit", 20) + + user = request.state.user + jwt_token = request.cookies.get("auth_token") + + result = await contexts_service.search_contexts(query, user_id=user.user_id, jwt_token=jwt_token, limit=limit) + return JSONResponse(result) + +async def get_context(request: Request, contexts_service, session_manager): + """Get a specific context by ID""" + context_id = request.path_params.get("context_id") + if not context_id: + return JSONResponse({"error": "Context ID is required"}, status_code=400) + + user = request.state.user + jwt_token = request.cookies.get("auth_token") + + result = await contexts_service.get_context(context_id, user_id=user.user_id, jwt_token=jwt_token) + return JSONResponse(result) + +async def update_context(request: Request, contexts_service, session_manager): + """Update an existing context by delete + recreate (due to DLS limitations)""" + context_id = request.path_params.get("context_id") + if not context_id: + return JSONResponse({"error": "Context ID is required"}, status_code=400) + + payload = await request.json() + + user = request.state.user + jwt_token = request.cookies.get("auth_token") + + # First, get the existing context + existing_result = await contexts_service.get_context(context_id, user_id=user.user_id, jwt_token=jwt_token) + if not existing_result.get("success"): + return JSONResponse({"error": "Context not found or access denied"}, status_code=404) + + existing_context = existing_result["context"] + + # Delete the existing context + delete_result = await contexts_service.delete_context(context_id, user_id=user.user_id, jwt_token=jwt_token) + if not delete_result.get("success"): + return JSONResponse({"error": "Failed to delete existing context"}, status_code=500) + + # Create updated context document with same ID + updated_context = { + "id": context_id, + "name": payload.get("name", existing_context["name"]), + "description": payload.get("description", existing_context["description"]), + "query_data": payload.get("queryData", existing_context["query_data"]), + "owner": existing_context["owner"], + "allowed_users": payload.get("allowedUsers", existing_context.get("allowed_users", [])), + "allowed_groups": payload.get("allowedGroups", existing_context.get("allowed_groups", [])), + "created_at": existing_context["created_at"], # Preserve original creation time + "updated_at": datetime.utcnow().isoformat() + } + + # Recreate the context + result = await contexts_service.create_context(updated_context, user_id=user.user_id, jwt_token=jwt_token) + return JSONResponse(result) + +async def delete_context(request: Request, contexts_service, session_manager): + """Delete a context""" + context_id = request.path_params.get("context_id") + if not context_id: + return JSONResponse({"error": "Context ID is required"}, status_code=400) + + user = request.state.user + jwt_token = request.cookies.get("auth_token") + + result = await contexts_service.delete_context(context_id, user_id=user.user_id, jwt_token=jwt_token) + return JSONResponse(result) \ No newline at end of file diff --git a/src/main.py b/src/main.py index 367b2005..82e93efb 100644 --- a/src/main.py +++ b/src/main.py @@ -23,6 +23,7 @@ from services.search_service import SearchService from services.task_service import TaskService from services.auth_service import AuthService from services.chat_service import ChatService +from services.contexts_service import ContextsService # Existing services from connectors.service import ConnectorService @@ -30,7 +31,7 @@ from session_manager import SessionManager from auth_middleware import require_auth, optional_auth # API endpoints -from api import upload, search, chat, auth, connectors, tasks, oidc +from api import upload, search, chat, auth, connectors, tasks, oidc, contexts print("CUDA available:", torch.cuda.is_available()) print("CUDA version PyTorch was built with:", torch.version.cuda) @@ -56,11 +57,36 @@ async def init_index(): """Initialize OpenSearch index and security roles""" await wait_for_opensearch() + # Create documents index if not await clients.opensearch.indices.exists(index=INDEX_NAME): await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY) print(f"Created index '{INDEX_NAME}'") else: print(f"Index '{INDEX_NAME}' already exists, skipping creation.") + + # Create contexts index + contexts_index_name = "search_contexts" + contexts_index_body = { + "mappings": { + "properties": { + "id": {"type": "keyword"}, + "name": {"type": "text", "analyzer": "standard"}, + "description": {"type": "text", "analyzer": "standard"}, + "query_data": {"type": "text"}, # Store as text for searching + "owner": {"type": "keyword"}, + "allowed_users": {"type": "keyword"}, + "allowed_groups": {"type": "keyword"}, + "created_at": {"type": "date"}, + "updated_at": {"type": "date"} + } + } + } + + if not await clients.opensearch.indices.exists(index=contexts_index_name): + await clients.opensearch.indices.create(index=contexts_index_name, body=contexts_index_body) + print(f"Created index '{contexts_index_name}'") + else: + print(f"Index '{contexts_index_name}' already exists, skipping creation.") async def init_index_when_ready(): """Initialize OpenSearch index when it becomes available""" @@ -85,6 +111,7 @@ def initialize_services(): search_service = SearchService(session_manager) task_service = TaskService(document_service, process_pool) chat_service = ChatService() + contexts_service = ContextsService(session_manager) # Set process pool for document service document_service.process_pool = process_pool @@ -109,6 +136,7 @@ def initialize_services(): 'chat_service': chat_service, 'auth_service': auth_service, 'connector_service': connector_service, + 'contexts_service': contexts_service, 'session_manager': session_manager } @@ -170,6 +198,42 @@ def create_app(): session_manager=services['session_manager']) ), methods=["POST"]), + # Contexts endpoints + Route("/contexts", + require_auth(services['session_manager'])( + partial(contexts.create_context, + contexts_service=services['contexts_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + Route("/contexts/search", + require_auth(services['session_manager'])( + partial(contexts.search_contexts, + contexts_service=services['contexts_service'], + session_manager=services['session_manager']) + ), methods=["POST"]), + + Route("/contexts/{context_id}", + require_auth(services['session_manager'])( + partial(contexts.get_context, + contexts_service=services['contexts_service'], + session_manager=services['session_manager']) + ), methods=["GET"]), + + Route("/contexts/{context_id}", + require_auth(services['session_manager'])( + partial(contexts.update_context, + contexts_service=services['contexts_service'], + session_manager=services['session_manager']) + ), methods=["PUT"]), + + Route("/contexts/{context_id}", + require_auth(services['session_manager'])( + partial(contexts.delete_context, + contexts_service=services['contexts_service'], + session_manager=services['session_manager']) + ), methods=["DELETE"]), + # Chat endpoints Route("/chat", require_auth(services['session_manager'])( diff --git a/src/services/contexts_service.py b/src/services/contexts_service.py new file mode 100644 index 00000000..5848b5dc --- /dev/null +++ b/src/services/contexts_service.py @@ -0,0 +1,131 @@ +from typing import Any, Dict, Optional + +CONTEXTS_INDEX_NAME = "search_contexts" + +class ContextsService: + def __init__(self, session_manager=None): + self.session_manager = session_manager + + async def create_context(self, context_doc: Dict[str, Any], user_id: str = None, jwt_token: str = None) -> Dict[str, Any]: + """Create a new search context""" + try: + # Get user's OpenSearch client with JWT for OIDC auth + opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token) + + # Index the context document + result = await opensearch_client.index( + index=CONTEXTS_INDEX_NAME, + id=context_doc["id"], + body=context_doc + ) + + if result.get("result") == "created": + return {"success": True, "id": context_doc["id"], "context": context_doc} + else: + return {"success": False, "error": "Failed to create context"} + + except Exception as e: + return {"success": False, "error": str(e)} + + async def search_contexts(self, query: str, user_id: str = None, jwt_token: str = None, limit: int = 20) -> Dict[str, Any]: + """Search for contexts by name, description, or query content""" + try: + # Get user's OpenSearch client with JWT for OIDC auth + opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token) + + if query.strip(): + # Search across name, description, and query_data fields + search_body = { + "query": { + "multi_match": { + "query": query, + "fields": ["name^3", "description^2", "query_data"], + "type": "best_fields", + "fuzziness": "AUTO" + } + }, + "sort": [ + {"_score": {"order": "desc"}}, + {"updated_at": {"order": "desc"}} + ], + "_source": ["id", "name", "description", "query_data", "owner", "created_at", "updated_at"], + "size": limit + } + else: + # No query - return all contexts sorted by most recent + search_body = { + "query": {"match_all": {}}, + "sort": [{"updated_at": {"order": "desc"}}], + "_source": ["id", "name", "description", "query_data", "owner", "created_at", "updated_at"], + "size": limit + } + + result = await opensearch_client.search(index=CONTEXTS_INDEX_NAME, body=search_body) + + # Transform results + contexts = [] + for hit in result["hits"]["hits"]: + context = hit["_source"] + context["score"] = hit.get("_score") + contexts.append(context) + + return {"success": True, "contexts": contexts} + + except Exception as e: + return {"success": False, "error": str(e), "contexts": []} + + async def get_context(self, context_id: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]: + """Get a specific context by ID""" + try: + # Get user's OpenSearch client with JWT for OIDC auth + opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token) + + result = await opensearch_client.get(index=CONTEXTS_INDEX_NAME, id=context_id) + + if result.get("found"): + context = result["_source"] + return {"success": True, "context": context} + else: + return {"success": False, "error": "Context not found"} + + except Exception as e: + return {"success": False, "error": str(e)} + + async def update_context(self, context_id: str, updates: Dict[str, Any], user_id: str = None, jwt_token: str = None) -> Dict[str, Any]: + """Update an existing context""" + try: + # Get user's OpenSearch client with JWT for OIDC auth + opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token) + + # Update the document + result = await opensearch_client.update( + index=CONTEXTS_INDEX_NAME, + id=context_id, + body={"doc": updates} + ) + + if result.get("result") in ["updated", "noop"]: + # Get the updated document + updated_doc = await opensearch_client.get(index=CONTEXTS_INDEX_NAME, id=context_id) + return {"success": True, "context": updated_doc["_source"]} + else: + return {"success": False, "error": "Failed to update context"} + + except Exception as e: + return {"success": False, "error": str(e)} + + async def delete_context(self, context_id: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]: + """Delete a context""" + try: + # Get user's OpenSearch client with JWT for OIDC auth + opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token) + + result = await opensearch_client.delete(index=CONTEXTS_INDEX_NAME, id=context_id) + + if result.get("result") == "deleted": + return {"success": True, "message": "Context deleted successfully"} + else: + return {"success": False, "error": "Failed to delete context"} + + except Exception as e: + return {"success": False, "error": str(e)} \ No newline at end of file