contexts endpoints and chat support via context vars
This commit is contained in:
parent
e87efaaeeb
commit
963c02e6a0
4 changed files with 337 additions and 1 deletions
|
|
@ -7,6 +7,9 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
||||||
prompt = data.get("prompt", "")
|
prompt = data.get("prompt", "")
|
||||||
previous_response_id = data.get("previous_response_id")
|
previous_response_id = data.get("previous_response_id")
|
||||||
stream = data.get("stream", False)
|
stream = data.get("stream", False)
|
||||||
|
filters = data.get("filters")
|
||||||
|
limit = data.get("limit", 10)
|
||||||
|
score_threshold = data.get("scoreThreshold", 0)
|
||||||
|
|
||||||
user = request.state.user
|
user = request.state.user
|
||||||
user_id = user.user_id
|
user_id = user.user_id
|
||||||
|
|
@ -16,6 +19,15 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
||||||
|
|
||||||
|
# Set context variables for search tool (similar to search endpoint)
|
||||||
|
if filters:
|
||||||
|
from auth_context import set_search_filters
|
||||||
|
set_search_filters(filters)
|
||||||
|
|
||||||
|
from auth_context import set_search_limit, set_score_threshold
|
||||||
|
set_search_limit(limit)
|
||||||
|
set_score_threshold(score_threshold)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
|
|
@ -38,6 +50,12 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
||||||
prompt = data.get("prompt", "")
|
prompt = data.get("prompt", "")
|
||||||
previous_response_id = data.get("previous_response_id")
|
previous_response_id = data.get("previous_response_id")
|
||||||
stream = data.get("stream", False)
|
stream = data.get("stream", False)
|
||||||
|
filters = data.get("filters")
|
||||||
|
limit = data.get("limit", 10)
|
||||||
|
score_threshold = data.get("scoreThreshold", 0)
|
||||||
|
|
||||||
|
user = request.state.user
|
||||||
|
user_id = user.user_id
|
||||||
|
|
||||||
# Get JWT token from request cookie
|
# Get JWT token from request cookie
|
||||||
jwt_token = request.cookies.get("auth_token")
|
jwt_token = request.cookies.get("auth_token")
|
||||||
|
|
@ -45,6 +63,15 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
||||||
if not prompt:
|
if not prompt:
|
||||||
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
||||||
|
|
||||||
|
# Set context variables for search tool (similar to chat endpoint)
|
||||||
|
if filters:
|
||||||
|
from auth_context import set_search_filters
|
||||||
|
set_search_filters(filters)
|
||||||
|
|
||||||
|
from auth_context import set_search_limit, set_score_threshold
|
||||||
|
set_search_limit(limit)
|
||||||
|
set_score_threshold(score_threshold)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if stream:
|
if stream:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
|
|
|
||||||
114
src/api/contexts.py
Normal file
114
src/api/contexts.py
Normal file
|
|
@ -0,0 +1,114 @@
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
async def create_context(request: Request, contexts_service, session_manager):
|
||||||
|
"""Create a new search context"""
|
||||||
|
payload = await request.json()
|
||||||
|
|
||||||
|
name = payload.get("name")
|
||||||
|
if not name:
|
||||||
|
return JSONResponse({"error": "Context name is required"}, status_code=400)
|
||||||
|
|
||||||
|
description = payload.get("description", "")
|
||||||
|
query_data = payload.get("queryData")
|
||||||
|
if not query_data:
|
||||||
|
return JSONResponse({"error": "Query data is required"}, status_code=400)
|
||||||
|
|
||||||
|
user = request.state.user
|
||||||
|
jwt_token = request.cookies.get("auth_token")
|
||||||
|
|
||||||
|
# Create context document
|
||||||
|
context_id = str(uuid.uuid4())
|
||||||
|
context_doc = {
|
||||||
|
"id": context_id,
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"query_data": query_data, # Store the full search query JSON
|
||||||
|
"owner": user.user_id,
|
||||||
|
"allowed_users": payload.get("allowedUsers", []), # ACL field for future use
|
||||||
|
"allowed_groups": payload.get("allowedGroups", []), # ACL field for future use
|
||||||
|
"created_at": datetime.utcnow().isoformat(),
|
||||||
|
"updated_at": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await contexts_service.create_context(context_doc, user_id=user.user_id, jwt_token=jwt_token)
|
||||||
|
return JSONResponse(result)
|
||||||
|
|
||||||
|
async def search_contexts(request: Request, contexts_service, session_manager):
|
||||||
|
"""Search for contexts by name, description, or query content"""
|
||||||
|
payload = await request.json()
|
||||||
|
|
||||||
|
query = payload.get("query", "")
|
||||||
|
limit = payload.get("limit", 20)
|
||||||
|
|
||||||
|
user = request.state.user
|
||||||
|
jwt_token = request.cookies.get("auth_token")
|
||||||
|
|
||||||
|
result = await contexts_service.search_contexts(query, user_id=user.user_id, jwt_token=jwt_token, limit=limit)
|
||||||
|
return JSONResponse(result)
|
||||||
|
|
||||||
|
async def get_context(request: Request, contexts_service, session_manager):
|
||||||
|
"""Get a specific context by ID"""
|
||||||
|
context_id = request.path_params.get("context_id")
|
||||||
|
if not context_id:
|
||||||
|
return JSONResponse({"error": "Context ID is required"}, status_code=400)
|
||||||
|
|
||||||
|
user = request.state.user
|
||||||
|
jwt_token = request.cookies.get("auth_token")
|
||||||
|
|
||||||
|
result = await contexts_service.get_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
|
||||||
|
return JSONResponse(result)
|
||||||
|
|
||||||
|
async def update_context(request: Request, contexts_service, session_manager):
|
||||||
|
"""Update an existing context by delete + recreate (due to DLS limitations)"""
|
||||||
|
context_id = request.path_params.get("context_id")
|
||||||
|
if not context_id:
|
||||||
|
return JSONResponse({"error": "Context ID is required"}, status_code=400)
|
||||||
|
|
||||||
|
payload = await request.json()
|
||||||
|
|
||||||
|
user = request.state.user
|
||||||
|
jwt_token = request.cookies.get("auth_token")
|
||||||
|
|
||||||
|
# First, get the existing context
|
||||||
|
existing_result = await contexts_service.get_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
|
||||||
|
if not existing_result.get("success"):
|
||||||
|
return JSONResponse({"error": "Context not found or access denied"}, status_code=404)
|
||||||
|
|
||||||
|
existing_context = existing_result["context"]
|
||||||
|
|
||||||
|
# Delete the existing context
|
||||||
|
delete_result = await contexts_service.delete_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
|
||||||
|
if not delete_result.get("success"):
|
||||||
|
return JSONResponse({"error": "Failed to delete existing context"}, status_code=500)
|
||||||
|
|
||||||
|
# Create updated context document with same ID
|
||||||
|
updated_context = {
|
||||||
|
"id": context_id,
|
||||||
|
"name": payload.get("name", existing_context["name"]),
|
||||||
|
"description": payload.get("description", existing_context["description"]),
|
||||||
|
"query_data": payload.get("queryData", existing_context["query_data"]),
|
||||||
|
"owner": existing_context["owner"],
|
||||||
|
"allowed_users": payload.get("allowedUsers", existing_context.get("allowed_users", [])),
|
||||||
|
"allowed_groups": payload.get("allowedGroups", existing_context.get("allowed_groups", [])),
|
||||||
|
"created_at": existing_context["created_at"], # Preserve original creation time
|
||||||
|
"updated_at": datetime.utcnow().isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Recreate the context
|
||||||
|
result = await contexts_service.create_context(updated_context, user_id=user.user_id, jwt_token=jwt_token)
|
||||||
|
return JSONResponse(result)
|
||||||
|
|
||||||
|
async def delete_context(request: Request, contexts_service, session_manager):
|
||||||
|
"""Delete a context"""
|
||||||
|
context_id = request.path_params.get("context_id")
|
||||||
|
if not context_id:
|
||||||
|
return JSONResponse({"error": "Context ID is required"}, status_code=400)
|
||||||
|
|
||||||
|
user = request.state.user
|
||||||
|
jwt_token = request.cookies.get("auth_token")
|
||||||
|
|
||||||
|
result = await contexts_service.delete_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
|
||||||
|
return JSONResponse(result)
|
||||||
66
src/main.py
66
src/main.py
|
|
@ -23,6 +23,7 @@ from services.search_service import SearchService
|
||||||
from services.task_service import TaskService
|
from services.task_service import TaskService
|
||||||
from services.auth_service import AuthService
|
from services.auth_service import AuthService
|
||||||
from services.chat_service import ChatService
|
from services.chat_service import ChatService
|
||||||
|
from services.contexts_service import ContextsService
|
||||||
|
|
||||||
# Existing services
|
# Existing services
|
||||||
from connectors.service import ConnectorService
|
from connectors.service import ConnectorService
|
||||||
|
|
@ -30,7 +31,7 @@ from session_manager import SessionManager
|
||||||
from auth_middleware import require_auth, optional_auth
|
from auth_middleware import require_auth, optional_auth
|
||||||
|
|
||||||
# API endpoints
|
# API endpoints
|
||||||
from api import upload, search, chat, auth, connectors, tasks, oidc
|
from api import upload, search, chat, auth, connectors, tasks, oidc, contexts
|
||||||
|
|
||||||
print("CUDA available:", torch.cuda.is_available())
|
print("CUDA available:", torch.cuda.is_available())
|
||||||
print("CUDA version PyTorch was built with:", torch.version.cuda)
|
print("CUDA version PyTorch was built with:", torch.version.cuda)
|
||||||
|
|
@ -56,11 +57,36 @@ async def init_index():
|
||||||
"""Initialize OpenSearch index and security roles"""
|
"""Initialize OpenSearch index and security roles"""
|
||||||
await wait_for_opensearch()
|
await wait_for_opensearch()
|
||||||
|
|
||||||
|
# Create documents index
|
||||||
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
|
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
|
||||||
await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
|
await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
|
||||||
print(f"Created index '{INDEX_NAME}'")
|
print(f"Created index '{INDEX_NAME}'")
|
||||||
else:
|
else:
|
||||||
print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
|
print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
|
||||||
|
|
||||||
|
# Create contexts index
|
||||||
|
contexts_index_name = "search_contexts"
|
||||||
|
contexts_index_body = {
|
||||||
|
"mappings": {
|
||||||
|
"properties": {
|
||||||
|
"id": {"type": "keyword"},
|
||||||
|
"name": {"type": "text", "analyzer": "standard"},
|
||||||
|
"description": {"type": "text", "analyzer": "standard"},
|
||||||
|
"query_data": {"type": "text"}, # Store as text for searching
|
||||||
|
"owner": {"type": "keyword"},
|
||||||
|
"allowed_users": {"type": "keyword"},
|
||||||
|
"allowed_groups": {"type": "keyword"},
|
||||||
|
"created_at": {"type": "date"},
|
||||||
|
"updated_at": {"type": "date"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if not await clients.opensearch.indices.exists(index=contexts_index_name):
|
||||||
|
await clients.opensearch.indices.create(index=contexts_index_name, body=contexts_index_body)
|
||||||
|
print(f"Created index '{contexts_index_name}'")
|
||||||
|
else:
|
||||||
|
print(f"Index '{contexts_index_name}' already exists, skipping creation.")
|
||||||
|
|
||||||
async def init_index_when_ready():
|
async def init_index_when_ready():
|
||||||
"""Initialize OpenSearch index when it becomes available"""
|
"""Initialize OpenSearch index when it becomes available"""
|
||||||
|
|
@ -85,6 +111,7 @@ def initialize_services():
|
||||||
search_service = SearchService(session_manager)
|
search_service = SearchService(session_manager)
|
||||||
task_service = TaskService(document_service, process_pool)
|
task_service = TaskService(document_service, process_pool)
|
||||||
chat_service = ChatService()
|
chat_service = ChatService()
|
||||||
|
contexts_service = ContextsService(session_manager)
|
||||||
|
|
||||||
# Set process pool for document service
|
# Set process pool for document service
|
||||||
document_service.process_pool = process_pool
|
document_service.process_pool = process_pool
|
||||||
|
|
@ -109,6 +136,7 @@ def initialize_services():
|
||||||
'chat_service': chat_service,
|
'chat_service': chat_service,
|
||||||
'auth_service': auth_service,
|
'auth_service': auth_service,
|
||||||
'connector_service': connector_service,
|
'connector_service': connector_service,
|
||||||
|
'contexts_service': contexts_service,
|
||||||
'session_manager': session_manager
|
'session_manager': session_manager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -170,6 +198,42 @@ def create_app():
|
||||||
session_manager=services['session_manager'])
|
session_manager=services['session_manager'])
|
||||||
), methods=["POST"]),
|
), methods=["POST"]),
|
||||||
|
|
||||||
|
# Contexts endpoints
|
||||||
|
Route("/contexts",
|
||||||
|
require_auth(services['session_manager'])(
|
||||||
|
partial(contexts.create_context,
|
||||||
|
contexts_service=services['contexts_service'],
|
||||||
|
session_manager=services['session_manager'])
|
||||||
|
), methods=["POST"]),
|
||||||
|
|
||||||
|
Route("/contexts/search",
|
||||||
|
require_auth(services['session_manager'])(
|
||||||
|
partial(contexts.search_contexts,
|
||||||
|
contexts_service=services['contexts_service'],
|
||||||
|
session_manager=services['session_manager'])
|
||||||
|
), methods=["POST"]),
|
||||||
|
|
||||||
|
Route("/contexts/{context_id}",
|
||||||
|
require_auth(services['session_manager'])(
|
||||||
|
partial(contexts.get_context,
|
||||||
|
contexts_service=services['contexts_service'],
|
||||||
|
session_manager=services['session_manager'])
|
||||||
|
), methods=["GET"]),
|
||||||
|
|
||||||
|
Route("/contexts/{context_id}",
|
||||||
|
require_auth(services['session_manager'])(
|
||||||
|
partial(contexts.update_context,
|
||||||
|
contexts_service=services['contexts_service'],
|
||||||
|
session_manager=services['session_manager'])
|
||||||
|
), methods=["PUT"]),
|
||||||
|
|
||||||
|
Route("/contexts/{context_id}",
|
||||||
|
require_auth(services['session_manager'])(
|
||||||
|
partial(contexts.delete_context,
|
||||||
|
contexts_service=services['contexts_service'],
|
||||||
|
session_manager=services['session_manager'])
|
||||||
|
), methods=["DELETE"]),
|
||||||
|
|
||||||
# Chat endpoints
|
# Chat endpoints
|
||||||
Route("/chat",
|
Route("/chat",
|
||||||
require_auth(services['session_manager'])(
|
require_auth(services['session_manager'])(
|
||||||
|
|
|
||||||
131
src/services/contexts_service.py
Normal file
131
src/services/contexts_service.py
Normal file
|
|
@ -0,0 +1,131 @@
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
CONTEXTS_INDEX_NAME = "search_contexts"
|
||||||
|
|
||||||
|
class ContextsService:
|
||||||
|
def __init__(self, session_manager=None):
|
||||||
|
self.session_manager = session_manager
|
||||||
|
|
||||||
|
async def create_context(self, context_doc: Dict[str, Any], user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||||
|
"""Create a new search context"""
|
||||||
|
try:
|
||||||
|
# Get user's OpenSearch client with JWT for OIDC auth
|
||||||
|
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||||
|
|
||||||
|
# Index the context document
|
||||||
|
result = await opensearch_client.index(
|
||||||
|
index=CONTEXTS_INDEX_NAME,
|
||||||
|
id=context_doc["id"],
|
||||||
|
body=context_doc
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.get("result") == "created":
|
||||||
|
return {"success": True, "id": context_doc["id"], "context": context_doc}
|
||||||
|
else:
|
||||||
|
return {"success": False, "error": "Failed to create context"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
async def search_contexts(self, query: str, user_id: str = None, jwt_token: str = None, limit: int = 20) -> Dict[str, Any]:
|
||||||
|
"""Search for contexts by name, description, or query content"""
|
||||||
|
try:
|
||||||
|
# Get user's OpenSearch client with JWT for OIDC auth
|
||||||
|
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||||
|
|
||||||
|
if query.strip():
|
||||||
|
# Search across name, description, and query_data fields
|
||||||
|
search_body = {
|
||||||
|
"query": {
|
||||||
|
"multi_match": {
|
||||||
|
"query": query,
|
||||||
|
"fields": ["name^3", "description^2", "query_data"],
|
||||||
|
"type": "best_fields",
|
||||||
|
"fuzziness": "AUTO"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"sort": [
|
||||||
|
{"_score": {"order": "desc"}},
|
||||||
|
{"updated_at": {"order": "desc"}}
|
||||||
|
],
|
||||||
|
"_source": ["id", "name", "description", "query_data", "owner", "created_at", "updated_at"],
|
||||||
|
"size": limit
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# No query - return all contexts sorted by most recent
|
||||||
|
search_body = {
|
||||||
|
"query": {"match_all": {}},
|
||||||
|
"sort": [{"updated_at": {"order": "desc"}}],
|
||||||
|
"_source": ["id", "name", "description", "query_data", "owner", "created_at", "updated_at"],
|
||||||
|
"size": limit
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await opensearch_client.search(index=CONTEXTS_INDEX_NAME, body=search_body)
|
||||||
|
|
||||||
|
# Transform results
|
||||||
|
contexts = []
|
||||||
|
for hit in result["hits"]["hits"]:
|
||||||
|
context = hit["_source"]
|
||||||
|
context["score"] = hit.get("_score")
|
||||||
|
contexts.append(context)
|
||||||
|
|
||||||
|
return {"success": True, "contexts": contexts}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e), "contexts": []}
|
||||||
|
|
||||||
|
async def get_context(self, context_id: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||||
|
"""Get a specific context by ID"""
|
||||||
|
try:
|
||||||
|
# Get user's OpenSearch client with JWT for OIDC auth
|
||||||
|
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||||
|
|
||||||
|
result = await opensearch_client.get(index=CONTEXTS_INDEX_NAME, id=context_id)
|
||||||
|
|
||||||
|
if result.get("found"):
|
||||||
|
context = result["_source"]
|
||||||
|
return {"success": True, "context": context}
|
||||||
|
else:
|
||||||
|
return {"success": False, "error": "Context not found"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
async def update_context(self, context_id: str, updates: Dict[str, Any], user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||||
|
"""Update an existing context"""
|
||||||
|
try:
|
||||||
|
# Get user's OpenSearch client with JWT for OIDC auth
|
||||||
|
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||||
|
|
||||||
|
# Update the document
|
||||||
|
result = await opensearch_client.update(
|
||||||
|
index=CONTEXTS_INDEX_NAME,
|
||||||
|
id=context_id,
|
||||||
|
body={"doc": updates}
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.get("result") in ["updated", "noop"]:
|
||||||
|
# Get the updated document
|
||||||
|
updated_doc = await opensearch_client.get(index=CONTEXTS_INDEX_NAME, id=context_id)
|
||||||
|
return {"success": True, "context": updated_doc["_source"]}
|
||||||
|
else:
|
||||||
|
return {"success": False, "error": "Failed to update context"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
async def delete_context(self, context_id: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||||
|
"""Delete a context"""
|
||||||
|
try:
|
||||||
|
# Get user's OpenSearch client with JWT for OIDC auth
|
||||||
|
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||||
|
|
||||||
|
result = await opensearch_client.delete(index=CONTEXTS_INDEX_NAME, id=context_id)
|
||||||
|
|
||||||
|
if result.get("result") == "deleted":
|
||||||
|
return {"success": True, "message": "Context deleted successfully"}
|
||||||
|
else:
|
||||||
|
return {"success": False, "error": "Failed to delete context"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
Loading…
Add table
Reference in a new issue