contexts endpoints and chat support via context vars
This commit is contained in:
parent
e87efaaeeb
commit
963c02e6a0
4 changed files with 337 additions and 1 deletions
|
|
@ -7,6 +7,9 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
|||
prompt = data.get("prompt", "")
|
||||
previous_response_id = data.get("previous_response_id")
|
||||
stream = data.get("stream", False)
|
||||
filters = data.get("filters")
|
||||
limit = data.get("limit", 10)
|
||||
score_threshold = data.get("scoreThreshold", 0)
|
||||
|
||||
user = request.state.user
|
||||
user_id = user.user_id
|
||||
|
|
@ -16,6 +19,15 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
|||
|
||||
if not prompt:
|
||||
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
||||
|
||||
# Set context variables for search tool (similar to search endpoint)
|
||||
if filters:
|
||||
from auth_context import set_search_filters
|
||||
set_search_filters(filters)
|
||||
|
||||
from auth_context import set_search_limit, set_score_threshold
|
||||
set_search_limit(limit)
|
||||
set_score_threshold(score_threshold)
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
|
|
@ -38,6 +50,12 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
|||
prompt = data.get("prompt", "")
|
||||
previous_response_id = data.get("previous_response_id")
|
||||
stream = data.get("stream", False)
|
||||
filters = data.get("filters")
|
||||
limit = data.get("limit", 10)
|
||||
score_threshold = data.get("scoreThreshold", 0)
|
||||
|
||||
user = request.state.user
|
||||
user_id = user.user_id
|
||||
|
||||
# Get JWT token from request cookie
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
|
@ -45,6 +63,15 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
|||
if not prompt:
|
||||
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
||||
|
||||
# Set context variables for search tool (similar to chat endpoint)
|
||||
if filters:
|
||||
from auth_context import set_search_filters
|
||||
set_search_filters(filters)
|
||||
|
||||
from auth_context import set_search_limit, set_score_threshold
|
||||
set_search_limit(limit)
|
||||
set_score_threshold(score_threshold)
|
||||
|
||||
try:
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
|
|
|
|||
114
src/api/contexts.py
Normal file
114
src/api/contexts.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
async def create_context(request: Request, contexts_service, session_manager):
|
||||
"""Create a new search context"""
|
||||
payload = await request.json()
|
||||
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
return JSONResponse({"error": "Context name is required"}, status_code=400)
|
||||
|
||||
description = payload.get("description", "")
|
||||
query_data = payload.get("queryData")
|
||||
if not query_data:
|
||||
return JSONResponse({"error": "Query data is required"}, status_code=400)
|
||||
|
||||
user = request.state.user
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
||||
# Create context document
|
||||
context_id = str(uuid.uuid4())
|
||||
context_doc = {
|
||||
"id": context_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"query_data": query_data, # Store the full search query JSON
|
||||
"owner": user.user_id,
|
||||
"allowed_users": payload.get("allowedUsers", []), # ACL field for future use
|
||||
"allowed_groups": payload.get("allowedGroups", []), # ACL field for future use
|
||||
"created_at": datetime.utcnow().isoformat(),
|
||||
"updated_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
result = await contexts_service.create_context(context_doc, user_id=user.user_id, jwt_token=jwt_token)
|
||||
return JSONResponse(result)
|
||||
|
||||
async def search_contexts(request: Request, contexts_service, session_manager):
|
||||
"""Search for contexts by name, description, or query content"""
|
||||
payload = await request.json()
|
||||
|
||||
query = payload.get("query", "")
|
||||
limit = payload.get("limit", 20)
|
||||
|
||||
user = request.state.user
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
||||
result = await contexts_service.search_contexts(query, user_id=user.user_id, jwt_token=jwt_token, limit=limit)
|
||||
return JSONResponse(result)
|
||||
|
||||
async def get_context(request: Request, contexts_service, session_manager):
|
||||
"""Get a specific context by ID"""
|
||||
context_id = request.path_params.get("context_id")
|
||||
if not context_id:
|
||||
return JSONResponse({"error": "Context ID is required"}, status_code=400)
|
||||
|
||||
user = request.state.user
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
||||
result = await contexts_service.get_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
|
||||
return JSONResponse(result)
|
||||
|
||||
async def update_context(request: Request, contexts_service, session_manager):
|
||||
"""Update an existing context by delete + recreate (due to DLS limitations)"""
|
||||
context_id = request.path_params.get("context_id")
|
||||
if not context_id:
|
||||
return JSONResponse({"error": "Context ID is required"}, status_code=400)
|
||||
|
||||
payload = await request.json()
|
||||
|
||||
user = request.state.user
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
||||
# First, get the existing context
|
||||
existing_result = await contexts_service.get_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
|
||||
if not existing_result.get("success"):
|
||||
return JSONResponse({"error": "Context not found or access denied"}, status_code=404)
|
||||
|
||||
existing_context = existing_result["context"]
|
||||
|
||||
# Delete the existing context
|
||||
delete_result = await contexts_service.delete_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
|
||||
if not delete_result.get("success"):
|
||||
return JSONResponse({"error": "Failed to delete existing context"}, status_code=500)
|
||||
|
||||
# Create updated context document with same ID
|
||||
updated_context = {
|
||||
"id": context_id,
|
||||
"name": payload.get("name", existing_context["name"]),
|
||||
"description": payload.get("description", existing_context["description"]),
|
||||
"query_data": payload.get("queryData", existing_context["query_data"]),
|
||||
"owner": existing_context["owner"],
|
||||
"allowed_users": payload.get("allowedUsers", existing_context.get("allowed_users", [])),
|
||||
"allowed_groups": payload.get("allowedGroups", existing_context.get("allowed_groups", [])),
|
||||
"created_at": existing_context["created_at"], # Preserve original creation time
|
||||
"updated_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# Recreate the context
|
||||
result = await contexts_service.create_context(updated_context, user_id=user.user_id, jwt_token=jwt_token)
|
||||
return JSONResponse(result)
|
||||
|
||||
async def delete_context(request: Request, contexts_service, session_manager):
|
||||
"""Delete a context"""
|
||||
context_id = request.path_params.get("context_id")
|
||||
if not context_id:
|
||||
return JSONResponse({"error": "Context ID is required"}, status_code=400)
|
||||
|
||||
user = request.state.user
|
||||
jwt_token = request.cookies.get("auth_token")
|
||||
|
||||
result = await contexts_service.delete_context(context_id, user_id=user.user_id, jwt_token=jwt_token)
|
||||
return JSONResponse(result)
|
||||
66
src/main.py
66
src/main.py
|
|
@ -23,6 +23,7 @@ from services.search_service import SearchService
|
|||
from services.task_service import TaskService
|
||||
from services.auth_service import AuthService
|
||||
from services.chat_service import ChatService
|
||||
from services.contexts_service import ContextsService
|
||||
|
||||
# Existing services
|
||||
from connectors.service import ConnectorService
|
||||
|
|
@ -30,7 +31,7 @@ from session_manager import SessionManager
|
|||
from auth_middleware import require_auth, optional_auth
|
||||
|
||||
# API endpoints
|
||||
from api import upload, search, chat, auth, connectors, tasks, oidc
|
||||
from api import upload, search, chat, auth, connectors, tasks, oidc, contexts
|
||||
|
||||
print("CUDA available:", torch.cuda.is_available())
|
||||
print("CUDA version PyTorch was built with:", torch.version.cuda)
|
||||
|
|
@ -56,11 +57,36 @@ async def init_index():
|
|||
"""Initialize OpenSearch index and security roles"""
|
||||
await wait_for_opensearch()
|
||||
|
||||
# Create documents index
|
||||
if not await clients.opensearch.indices.exists(index=INDEX_NAME):
|
||||
await clients.opensearch.indices.create(index=INDEX_NAME, body=INDEX_BODY)
|
||||
print(f"Created index '{INDEX_NAME}'")
|
||||
else:
|
||||
print(f"Index '{INDEX_NAME}' already exists, skipping creation.")
|
||||
|
||||
# Create contexts index
|
||||
contexts_index_name = "search_contexts"
|
||||
contexts_index_body = {
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"id": {"type": "keyword"},
|
||||
"name": {"type": "text", "analyzer": "standard"},
|
||||
"description": {"type": "text", "analyzer": "standard"},
|
||||
"query_data": {"type": "text"}, # Store as text for searching
|
||||
"owner": {"type": "keyword"},
|
||||
"allowed_users": {"type": "keyword"},
|
||||
"allowed_groups": {"type": "keyword"},
|
||||
"created_at": {"type": "date"},
|
||||
"updated_at": {"type": "date"}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if not await clients.opensearch.indices.exists(index=contexts_index_name):
|
||||
await clients.opensearch.indices.create(index=contexts_index_name, body=contexts_index_body)
|
||||
print(f"Created index '{contexts_index_name}'")
|
||||
else:
|
||||
print(f"Index '{contexts_index_name}' already exists, skipping creation.")
|
||||
|
||||
async def init_index_when_ready():
|
||||
"""Initialize OpenSearch index when it becomes available"""
|
||||
|
|
@ -85,6 +111,7 @@ def initialize_services():
|
|||
search_service = SearchService(session_manager)
|
||||
task_service = TaskService(document_service, process_pool)
|
||||
chat_service = ChatService()
|
||||
contexts_service = ContextsService(session_manager)
|
||||
|
||||
# Set process pool for document service
|
||||
document_service.process_pool = process_pool
|
||||
|
|
@ -109,6 +136,7 @@ def initialize_services():
|
|||
'chat_service': chat_service,
|
||||
'auth_service': auth_service,
|
||||
'connector_service': connector_service,
|
||||
'contexts_service': contexts_service,
|
||||
'session_manager': session_manager
|
||||
}
|
||||
|
||||
|
|
@ -170,6 +198,42 @@ def create_app():
|
|||
session_manager=services['session_manager'])
|
||||
), methods=["POST"]),
|
||||
|
||||
# Contexts endpoints
|
||||
Route("/contexts",
|
||||
require_auth(services['session_manager'])(
|
||||
partial(contexts.create_context,
|
||||
contexts_service=services['contexts_service'],
|
||||
session_manager=services['session_manager'])
|
||||
), methods=["POST"]),
|
||||
|
||||
Route("/contexts/search",
|
||||
require_auth(services['session_manager'])(
|
||||
partial(contexts.search_contexts,
|
||||
contexts_service=services['contexts_service'],
|
||||
session_manager=services['session_manager'])
|
||||
), methods=["POST"]),
|
||||
|
||||
Route("/contexts/{context_id}",
|
||||
require_auth(services['session_manager'])(
|
||||
partial(contexts.get_context,
|
||||
contexts_service=services['contexts_service'],
|
||||
session_manager=services['session_manager'])
|
||||
), methods=["GET"]),
|
||||
|
||||
Route("/contexts/{context_id}",
|
||||
require_auth(services['session_manager'])(
|
||||
partial(contexts.update_context,
|
||||
contexts_service=services['contexts_service'],
|
||||
session_manager=services['session_manager'])
|
||||
), methods=["PUT"]),
|
||||
|
||||
Route("/contexts/{context_id}",
|
||||
require_auth(services['session_manager'])(
|
||||
partial(contexts.delete_context,
|
||||
contexts_service=services['contexts_service'],
|
||||
session_manager=services['session_manager'])
|
||||
), methods=["DELETE"]),
|
||||
|
||||
# Chat endpoints
|
||||
Route("/chat",
|
||||
require_auth(services['session_manager'])(
|
||||
|
|
|
|||
131
src/services/contexts_service.py
Normal file
131
src/services/contexts_service.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
CONTEXTS_INDEX_NAME = "search_contexts"
|
||||
|
||||
class ContextsService:
|
||||
def __init__(self, session_manager=None):
|
||||
self.session_manager = session_manager
|
||||
|
||||
async def create_context(self, context_doc: Dict[str, Any], user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||
"""Create a new search context"""
|
||||
try:
|
||||
# Get user's OpenSearch client with JWT for OIDC auth
|
||||
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||
|
||||
# Index the context document
|
||||
result = await opensearch_client.index(
|
||||
index=CONTEXTS_INDEX_NAME,
|
||||
id=context_doc["id"],
|
||||
body=context_doc
|
||||
)
|
||||
|
||||
if result.get("result") == "created":
|
||||
return {"success": True, "id": context_doc["id"], "context": context_doc}
|
||||
else:
|
||||
return {"success": False, "error": "Failed to create context"}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def search_contexts(self, query: str, user_id: str = None, jwt_token: str = None, limit: int = 20) -> Dict[str, Any]:
|
||||
"""Search for contexts by name, description, or query content"""
|
||||
try:
|
||||
# Get user's OpenSearch client with JWT for OIDC auth
|
||||
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||
|
||||
if query.strip():
|
||||
# Search across name, description, and query_data fields
|
||||
search_body = {
|
||||
"query": {
|
||||
"multi_match": {
|
||||
"query": query,
|
||||
"fields": ["name^3", "description^2", "query_data"],
|
||||
"type": "best_fields",
|
||||
"fuzziness": "AUTO"
|
||||
}
|
||||
},
|
||||
"sort": [
|
||||
{"_score": {"order": "desc"}},
|
||||
{"updated_at": {"order": "desc"}}
|
||||
],
|
||||
"_source": ["id", "name", "description", "query_data", "owner", "created_at", "updated_at"],
|
||||
"size": limit
|
||||
}
|
||||
else:
|
||||
# No query - return all contexts sorted by most recent
|
||||
search_body = {
|
||||
"query": {"match_all": {}},
|
||||
"sort": [{"updated_at": {"order": "desc"}}],
|
||||
"_source": ["id", "name", "description", "query_data", "owner", "created_at", "updated_at"],
|
||||
"size": limit
|
||||
}
|
||||
|
||||
result = await opensearch_client.search(index=CONTEXTS_INDEX_NAME, body=search_body)
|
||||
|
||||
# Transform results
|
||||
contexts = []
|
||||
for hit in result["hits"]["hits"]:
|
||||
context = hit["_source"]
|
||||
context["score"] = hit.get("_score")
|
||||
contexts.append(context)
|
||||
|
||||
return {"success": True, "contexts": contexts}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e), "contexts": []}
|
||||
|
||||
async def get_context(self, context_id: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||
"""Get a specific context by ID"""
|
||||
try:
|
||||
# Get user's OpenSearch client with JWT for OIDC auth
|
||||
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||
|
||||
result = await opensearch_client.get(index=CONTEXTS_INDEX_NAME, id=context_id)
|
||||
|
||||
if result.get("found"):
|
||||
context = result["_source"]
|
||||
return {"success": True, "context": context}
|
||||
else:
|
||||
return {"success": False, "error": "Context not found"}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def update_context(self, context_id: str, updates: Dict[str, Any], user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||
"""Update an existing context"""
|
||||
try:
|
||||
# Get user's OpenSearch client with JWT for OIDC auth
|
||||
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||
|
||||
# Update the document
|
||||
result = await opensearch_client.update(
|
||||
index=CONTEXTS_INDEX_NAME,
|
||||
id=context_id,
|
||||
body={"doc": updates}
|
||||
)
|
||||
|
||||
if result.get("result") in ["updated", "noop"]:
|
||||
# Get the updated document
|
||||
updated_doc = await opensearch_client.get(index=CONTEXTS_INDEX_NAME, id=context_id)
|
||||
return {"success": True, "context": updated_doc["_source"]}
|
||||
else:
|
||||
return {"success": False, "error": "Failed to update context"}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def delete_context(self, context_id: str, user_id: str = None, jwt_token: str = None) -> Dict[str, Any]:
|
||||
"""Delete a context"""
|
||||
try:
|
||||
# Get user's OpenSearch client with JWT for OIDC auth
|
||||
opensearch_client = self.session_manager.get_user_opensearch_client(user_id, jwt_token)
|
||||
|
||||
result = await opensearch_client.delete(index=CONTEXTS_INDEX_NAME, id=context_id)
|
||||
|
||||
if result.get("result") == "deleted":
|
||||
return {"success": True, "message": "Context deleted successfully"}
|
||||
else:
|
||||
return {"success": False, "error": "Failed to delete context"}
|
||||
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
Loading…
Add table
Reference in a new issue