RBAC basic implementation

This commit is contained in:
Edwin Jose 2025-12-26 16:41:39 -05:00
parent c923ecb396
commit ce51628db2
14 changed files with 409 additions and 98 deletions

View file

@ -6,6 +6,8 @@ import {
export interface CreateApiKeyRequest { export interface CreateApiKeyRequest {
name: string; name: string;
roles?: string[];
groups?: string[];
} }
export interface CreateApiKeyResponse { export interface CreateApiKeyResponse {
@ -14,6 +16,8 @@ export interface CreateApiKeyResponse {
name: string; name: string;
key_prefix: string; key_prefix: string;
created_at: string; created_at: string;
roles?: string[];
groups?: string[];
} }
export const useCreateApiKeyMutation = ( export const useCreateApiKeyMutation = (

View file

@ -6,6 +6,8 @@ export interface ApiKey {
key_prefix: string; key_prefix: string;
created_at: string; created_at: string;
last_used_at: string | null; last_used_at: string | null;
roles?: string[];
groups?: string[];
} }
export interface GetApiKeysResponse { export interface GetApiKeysResponse {

View file

@ -136,6 +136,7 @@ function KnowledgeSourcesPage() {
// API Keys state // API Keys state
const [createKeyDialogOpen, setCreateKeyDialogOpen] = useState(false); const [createKeyDialogOpen, setCreateKeyDialogOpen] = useState(false);
const [newKeyName, setNewKeyName] = useState(""); const [newKeyName, setNewKeyName] = useState("");
const [newKeyGroups, setNewKeyGroups] = useState("");
const [newlyCreatedKey, setNewlyCreatedKey] = useState<string | null>(null); const [newlyCreatedKey, setNewlyCreatedKey] = useState<string | null>(null);
const [showKeyDialogOpen, setShowKeyDialogOpen] = useState(false); const [showKeyDialogOpen, setShowKeyDialogOpen] = useState(false);
@ -156,6 +157,7 @@ function KnowledgeSourcesPage() {
setCreateKeyDialogOpen(false); setCreateKeyDialogOpen(false);
setShowKeyDialogOpen(true); setShowKeyDialogOpen(true);
setNewKeyName(""); setNewKeyName("");
setNewKeyGroups("");
toast.success("API key created"); toast.success("API key created");
}, },
onError: (error) => { onError: (error) => {
@ -438,7 +440,16 @@ function KnowledgeSourcesPage() {
toast.error("Please enter a name for the API key"); toast.error("Please enter a name for the API key");
return; return;
} }
createApiKeyMutation.mutate({ name: newKeyName.trim() }); // Parse groups from comma-separated string
const groups = newKeyGroups
.split(",")
.map((g) => g.trim())
.filter((g) => g.length > 0);
createApiKeyMutation.mutate({
name: newKeyName.trim(),
groups: groups.length > 0 ? groups : undefined,
});
}; };
const handleRevokeApiKey = (keyId: string) => { const handleRevokeApiKey = (keyId: string) => {
@ -1426,6 +1437,9 @@ function KnowledgeSourcesPage() {
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3"> <th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
Key Key
</th> </th>
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
Groups
</th>
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3"> <th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
Created Created
</th> </th>
@ -1448,6 +1462,24 @@ function KnowledgeSourcesPage() {
{key.key_prefix}... {key.key_prefix}...
</code> </code>
</td> </td>
<td className="px-4 py-3 text-sm text-muted-foreground">
{key.groups && key.groups.length > 0 ? (
<div className="flex flex-wrap gap-1">
{key.groups.map((group: string) => (
<span
key={group}
className="inline-flex items-center px-2 py-0.5 rounded-full text-xs bg-primary/10 text-primary"
>
{group}
</span>
))}
</div>
) : (
<span className="text-muted-foreground/50">
All groups
</span>
)}
</td>
<td className="px-4 py-3 text-sm text-muted-foreground"> <td className="px-4 py-3 text-sm text-muted-foreground">
{formatDate(key.created_at)} {formatDate(key.created_at)}
</td> </td>
@ -1507,58 +1539,77 @@ function KnowledgeSourcesPage() {
</Card> </Card>
)} )}
{/* Create API Key Dialog */} {/* Create API Key Dialog */}
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}> <Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
<DialogContent> <DialogContent>
<DialogHeader> <DialogHeader>
<DialogTitle>Create API Key</DialogTitle> <DialogTitle>Create API Key</DialogTitle>
<DialogDescription> <DialogDescription>
Give your API key a name to help you identify it later. Create an API key with optional group restrictions for access
</DialogDescription> control.
</DialogHeader> </DialogDescription>
<div className="py-4"> </DialogHeader>
<LabelWrapper label="Name" id="api-key-name"> <div className="py-4 space-y-4">
<Input <LabelWrapper label="Name" id="api-key-name">
id="api-key-name" <Input
placeholder="e.g., Production App, Development" id="api-key-name"
value={newKeyName} placeholder="e.g., Production App, Development"
onChange={(e) => setNewKeyName(e.target.value)} value={newKeyName}
onKeyDown={(e) => { onChange={(e) => setNewKeyName(e.target.value)}
if (e.key === "Enter") { onKeyDown={(e) => {
handleCreateApiKey(); if (e.key === "Enter") {
} handleCreateApiKey();
}} }
/>
</LabelWrapper>
</div>
<DialogFooter>
<Button
variant="ghost"
onClick={() => {
setCreateKeyDialogOpen(false);
setNewKeyName("");
}} }}
size="sm" />
> </LabelWrapper>
Cancel <LabelWrapper
</Button> label="Groups (optional)"
<Button id="api-key-groups"
onClick={handleCreateApiKey} helperText="Comma-separated list of groups this key can access"
disabled={createApiKeyMutation.isPending || !newKeyName.trim()} >
size="sm" <Input
> id="api-key-groups"
{createApiKeyMutation.isPending ? ( placeholder="e.g., finance, hr, engineering"
<> value={newKeyGroups}
<Loader2 className="h-4 w-4 mr-2 animate-spin" /> onChange={(e) => setNewKeyGroups(e.target.value)}
Creating... onKeyDown={(e) => {
</> if (e.key === "Enter") {
) : ( handleCreateApiKey();
"Create Key" }
)} }}
</Button> />
</DialogFooter> </LabelWrapper>
</DialogContent> </div>
</Dialog> <DialogFooter>
<Button
variant="ghost"
onClick={() => {
setCreateKeyDialogOpen(false);
setNewKeyName("");
setNewKeyGroups("");
}}
size="sm"
>
Cancel
</Button>
<Button
onClick={handleCreateApiKey}
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
size="sm"
>
{createApiKeyMutation.isPending ? (
<>
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
Creating...
</>
) : (
"Create Key"
)}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
{/* Show Created API Key Dialog */} {/* Show Created API Key Dialog */}
<Dialog <Dialog

View file

@ -35,6 +35,10 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
set_search_limit(limit) set_search_limit(limit)
set_score_threshold(score_threshold) set_score_threshold(score_threshold)
# Get RBAC groups and roles from user for access control
user_groups = getattr(user, "groups", [])
user_roles = getattr(user, "roles", [])
if stream: if stream:
return StreamingResponse( return StreamingResponse(
await chat_service.chat( await chat_service.chat(
@ -44,6 +48,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
stream=True, stream=True,
filter_id=filter_id, filter_id=filter_id,
groups=user_groups,
roles=user_roles,
), ),
media_type="text/event-stream", media_type="text/event-stream",
headers={ headers={
@ -61,6 +67,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
stream=False, stream=False,
filter_id=filter_id, filter_id=filter_id,
groups=user_groups,
roles=user_roles,
) )
return JSONResponse(result) return JSONResponse(result)
@ -81,6 +89,10 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token) jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
# Get RBAC groups and roles from user for access control
user_groups = getattr(user, "groups", [])
user_roles = getattr(user, "roles", [])
if not prompt: if not prompt:
return JSONResponse({"error": "Prompt is required"}, status_code=400) return JSONResponse({"error": "Prompt is required"}, status_code=400)
@ -105,6 +117,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
stream=True, stream=True,
filter_id=filter_id, filter_id=filter_id,
groups=user_groups,
roles=user_roles,
), ),
media_type="text/event-stream", media_type="text/event-stream",
headers={ headers={
@ -122,6 +136,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
stream=False, stream=False,
filter_id=filter_id, filter_id=filter_id,
groups=user_groups,
roles=user_roles,
) )
return JSONResponse(result) return JSONResponse(result)

View file

@ -42,10 +42,14 @@ async def list_keys_endpoint(request: Request, api_key_service):
async def create_key_endpoint(request: Request, api_key_service): async def create_key_endpoint(request: Request, api_key_service):
""" """
Create a new API key for the authenticated user. Create a new API key for the authenticated user with optional RBAC restrictions.
POST /keys POST /keys
Body: {"name": "My API Key"} Body: {
"name": "My API Key",
"roles": ["openrag_user"], // Optional: restrict key to specific roles
"groups": ["finance", "hr"] // Optional: restrict key to specific groups
}
Response: Response:
{ {
@ -53,6 +57,8 @@ async def create_key_endpoint(request: Request, api_key_service):
"key_id": "...", "key_id": "...",
"key_prefix": "orag_abc12345", "key_prefix": "orag_abc12345",
"name": "My API Key", "name": "My API Key",
"roles": ["openrag_user"],
"groups": ["finance", "hr"],
"created_at": "2024-01-01T00:00:00", "created_at": "2024-01-01T00:00:00",
"api_key": "orag_abc12345..." // Full key, only shown once! "api_key": "orag_abc12345..." // Full key, only shown once!
} }
@ -78,11 +84,29 @@ async def create_key_endpoint(request: Request, api_key_service):
status_code=400, status_code=400,
) )
# Extract optional RBAC fields
roles = data.get("roles")
groups = data.get("groups")
# Validate roles and groups are lists if provided
if roles is not None and not isinstance(roles, list):
return JSONResponse(
{"success": False, "error": "roles must be a list"},
status_code=400,
)
if groups is not None and not isinstance(groups, list):
return JSONResponse(
{"success": False, "error": "groups must be a list"},
status_code=400,
)
result = await api_key_service.create_key( result = await api_key_service.create_key(
user_id=user_id, user_id=user_id,
user_email=user_email, user_email=user_email,
name=name, name=name,
jwt_token=jwt_token, jwt_token=jwt_token,
roles=roles,
groups=groups,
) )
if result.get("success"): if result.get("success"):

View file

@ -104,14 +104,18 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
user = request.state.user user = request.state.user
user_id = user.user_id user_id = user.user_id
jwt_token = session_manager.get_effective_jwt_token(user_id, None) jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
# Get RBAC groups and roles from user for access control
user_groups = getattr(user, "groups", [])
user_roles = getattr(user, "roles", [])
# Set context variables for search tool # Set context variables for search tool
if filters: if filters:
set_search_filters(filters) set_search_filters(filters)
set_search_limit(limit) set_search_limit(limit)
set_score_threshold(score_threshold) set_score_threshold(score_threshold)
set_auth_context(user_id, jwt_token) set_auth_context(user_id, jwt_token, groups=user_groups, roles=user_roles)
if stream: if stream:
raw_stream = await chat_service.langflow_chat( raw_stream = await chat_service.langflow_chat(
@ -121,6 +125,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
previous_response_id=chat_id, previous_response_id=chat_id,
stream=True, stream=True,
filter_id=filter_id, filter_id=filter_id,
groups=user_groups,
roles=user_roles,
) )
chat_id_container = {} chat_id_container = {}
return StreamingResponse( return StreamingResponse(
@ -136,6 +142,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
previous_response_id=chat_id, previous_response_id=chat_id,
stream=False, stream=False,
filter_id=filter_id, filter_id=filter_id,
groups=user_groups,
roles=user_roles,
) )
# Transform response_id to chat_id for v1 API format # Transform response_id to chat_id for v1 API format
return JSONResponse({ return JSONResponse({

View file

@ -1,5 +1,8 @@
""" """
API Key middleware for authenticating public API requests. API Key middleware for authenticating public API requests.
This middleware validates API keys and generates ephemeral JWTs with the
key's specific roles and groups for downstream security enforcement.
""" """
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
@ -32,14 +35,18 @@ def _extract_api_key(request: Request) -> str | None:
return None return None
def require_api_key(api_key_service): def require_api_key(api_key_service, session_manager=None):
""" """
Decorator to require API key authentication for public API endpoints. Decorator to require API key authentication for public API endpoints.
Generates an ephemeral JWT with the API key's specific roles and groups
to enforce RBAC in downstream services (OpenSearch, tools, etc.).
Usage: Usage:
@require_api_key(api_key_service) @require_api_key(api_key_service, session_manager)
async def my_endpoint(request): async def my_endpoint(request):
user = request.state.user user = request.state.user
jwt_token = request.state.jwt_token # Ephemeral restricted JWT
... ...
""" """
@ -57,7 +64,7 @@ def require_api_key(api_key_service):
status_code=401, status_code=401,
) )
# Validate the key # Validate the key and get RBAC claims
user_info = await api_key_service.validate_key(api_key) user_info = await api_key_service.validate_key(api_key)
if not user_info: if not user_info:
@ -69,19 +76,46 @@ def require_api_key(api_key_service):
status_code=401, status_code=401,
) )
# Create a User object from the API key info # Extract RBAC fields from API key
key_roles = user_info.get("roles", ["openrag_user"])
key_groups = user_info.get("groups", [])
# Create a User object with the API key's roles and groups
user = User( user = User(
user_id=user_info["user_id"], user_id=user_info["user_id"],
email=user_info["user_email"], email=user_info["user_email"],
name=user_info.get("name", "API User"), name=user_info.get("name", "API User"),
picture=None, picture=None,
provider="api_key", provider="api_key",
roles=key_roles,
groups=key_groups,
) )
# Set request state # Set request state
request.state.user = user request.state.user = user
request.state.api_key_id = user_info["key_id"] request.state.api_key_id = user_info["key_id"]
request.state.jwt_token = None # No JWT for API key auth
# Generate ephemeral JWT with the API key's restricted roles/groups
if session_manager:
# Create a short-lived JWT with the key's specific permissions
ephemeral_jwt = session_manager.create_jwt_token(
user=user,
roles=key_roles,
groups=key_groups,
expiration_days=1, # Short-lived for API requests
)
request.state.jwt_token = ephemeral_jwt
logger.debug(
"Generated ephemeral JWT for API key",
key_id=user_info["key_id"],
roles=key_roles,
groups=key_groups,
)
else:
request.state.jwt_token = None
logger.warning(
"No session_manager provided - JWT not generated for API key"
)
return await handler(request) return await handler(request)
@ -90,10 +124,13 @@ def require_api_key(api_key_service):
return decorator return decorator
def optional_api_key(api_key_service): def optional_api_key(api_key_service, session_manager=None):
""" """
Decorator to optionally authenticate with API key. Decorator to optionally authenticate with API key.
Sets request.state.user to None if no valid API key is provided. Sets request.state.user to None if no valid API key is provided.
When a valid API key is provided, generates an ephemeral JWT with
the key's specific roles and groups.
""" """
def decorator(handler): def decorator(handler):
@ -102,21 +139,38 @@ def optional_api_key(api_key_service):
api_key = _extract_api_key(request) api_key = _extract_api_key(request)
if api_key: if api_key:
# Validate the key # Validate the key and get RBAC claims
user_info = await api_key_service.validate_key(api_key) user_info = await api_key_service.validate_key(api_key)
if user_info: if user_info:
# Create a User object from the API key info # Extract RBAC fields from API key
key_roles = user_info.get("roles", ["openrag_user"])
key_groups = user_info.get("groups", [])
# Create a User object with the API key's roles and groups
user = User( user = User(
user_id=user_info["user_id"], user_id=user_info["user_id"],
email=user_info["user_email"], email=user_info["user_email"],
name=user_info.get("name", "API User"), name=user_info.get("name", "API User"),
picture=None, picture=None,
provider="api_key", provider="api_key",
roles=key_roles,
groups=key_groups,
) )
request.state.user = user request.state.user = user
request.state.api_key_id = user_info["key_id"] request.state.api_key_id = user_info["key_id"]
request.state.jwt_token = None
# Generate ephemeral JWT with the API key's restricted roles/groups
if session_manager:
ephemeral_jwt = session_manager.create_jwt_token(
user=user,
roles=key_roles,
groups=key_groups,
expiration_days=1,
)
request.state.jwt_token = ephemeral_jwt
else:
request.state.jwt_token = None
else: else:
request.state.user = None request.state.user = None
request.state.api_key_id = None request.state.api_key_id = None

View file

@ -4,7 +4,7 @@ Uses contextvars to safely pass user auth info through async calls.
""" """
from contextvars import ContextVar from contextvars import ContextVar
from typing import Optional, Dict, Any from typing import Optional, Dict, Any, List
# Context variables for current request authentication # Context variables for current request authentication
_current_user_id: ContextVar[Optional[str]] = ContextVar( _current_user_id: ContextVar[Optional[str]] = ContextVar(
@ -13,6 +13,12 @@ _current_user_id: ContextVar[Optional[str]] = ContextVar(
_current_jwt_token: ContextVar[Optional[str]] = ContextVar( _current_jwt_token: ContextVar[Optional[str]] = ContextVar(
"current_jwt_token", default=None "current_jwt_token", default=None
) )
_current_user_groups: ContextVar[Optional[List[str]]] = ContextVar(
"current_user_groups", default=None
)
_current_user_roles: ContextVar[Optional[List[str]]] = ContextVar(
"current_user_roles", default=None
)
_current_search_filters: ContextVar[Optional[Dict[str, Any]]] = ContextVar( _current_search_filters: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
"current_search_filters", default=None "current_search_filters", default=None
) )
@ -24,10 +30,24 @@ _current_score_threshold: ContextVar[Optional[float]] = ContextVar(
) )
def set_auth_context(user_id: str, jwt_token: str): def set_auth_context(
"""Set authentication context for the current async context""" user_id: str,
jwt_token: str,
groups: Optional[List[str]] = None,
roles: Optional[List[str]] = None,
):
"""Set authentication context for the current async context
Args:
user_id: The user's ID
jwt_token: The JWT token for authentication
groups: Optional list of groups the user belongs to (for RBAC)
roles: Optional list of roles the user has (for RBAC)
"""
_current_user_id.set(user_id) _current_user_id.set(user_id)
_current_jwt_token.set(jwt_token) _current_jwt_token.set(jwt_token)
_current_user_groups.set(groups or [])
_current_user_roles.set(roles or [])
def get_current_user_id() -> Optional[str]: def get_current_user_id() -> Optional[str]:
@ -40,6 +60,16 @@ def get_current_jwt_token() -> Optional[str]:
return _current_jwt_token.get() return _current_jwt_token.get()
def get_current_user_groups() -> List[str]:
"""Get current user's groups from context (for RBAC)"""
return _current_user_groups.get() or []
def get_current_user_roles() -> List[str]:
"""Get current user's roles from context (for RBAC)"""
return _current_user_roles.get() or []
def get_auth_context() -> tuple[Optional[str], Optional[str]]: def get_auth_context() -> tuple[Optional[str], Optional[str]]:
"""Get current authentication context (user_id, jwt_token)""" """Get current authentication context (user_id, jwt_token)"""
return _current_user_id.get(), _current_jwt_token.get() return _current_user_id.get(), _current_jwt_token.get()

View file

@ -1314,7 +1314,7 @@ async def create_app():
# Chat endpoints # Chat endpoints
Route( Route(
"/v1/chat", "/v1/chat",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_chat.chat_create_endpoint, v1_chat.chat_create_endpoint,
chat_service=services["chat_service"], chat_service=services["chat_service"],
@ -1325,7 +1325,7 @@ async def create_app():
), ),
Route( Route(
"/v1/chat", "/v1/chat",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_chat.chat_list_endpoint, v1_chat.chat_list_endpoint,
chat_service=services["chat_service"], chat_service=services["chat_service"],
@ -1336,7 +1336,7 @@ async def create_app():
), ),
Route( Route(
"/v1/chat/{chat_id}", "/v1/chat/{chat_id}",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_chat.chat_get_endpoint, v1_chat.chat_get_endpoint,
chat_service=services["chat_service"], chat_service=services["chat_service"],
@ -1347,7 +1347,7 @@ async def create_app():
), ),
Route( Route(
"/v1/chat/{chat_id}", "/v1/chat/{chat_id}",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_chat.chat_delete_endpoint, v1_chat.chat_delete_endpoint,
chat_service=services["chat_service"], chat_service=services["chat_service"],
@ -1359,7 +1359,7 @@ async def create_app():
# Search endpoint # Search endpoint
Route( Route(
"/v1/search", "/v1/search",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_search.search_endpoint, v1_search.search_endpoint,
search_service=services["search_service"], search_service=services["search_service"],
@ -1371,7 +1371,7 @@ async def create_app():
# Documents endpoints # Documents endpoints
Route( Route(
"/v1/documents/ingest", "/v1/documents/ingest",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_documents.ingest_endpoint, v1_documents.ingest_endpoint,
document_service=services["document_service"], document_service=services["document_service"],
@ -1384,7 +1384,7 @@ async def create_app():
), ),
Route( Route(
"/v1/tasks/{task_id}", "/v1/tasks/{task_id}",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_documents.task_status_endpoint, v1_documents.task_status_endpoint,
task_service=services["task_service"], task_service=services["task_service"],
@ -1395,7 +1395,7 @@ async def create_app():
), ),
Route( Route(
"/v1/documents", "/v1/documents",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_documents.delete_document_endpoint, v1_documents.delete_document_endpoint,
document_service=services["document_service"], document_service=services["document_service"],
@ -1407,14 +1407,14 @@ async def create_app():
# Settings endpoints # Settings endpoints
Route( Route(
"/v1/settings", "/v1/settings",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial(v1_settings.get_settings_endpoint) partial(v1_settings.get_settings_endpoint)
), ),
methods=["GET"], methods=["GET"],
), ),
Route( Route(
"/v1/settings", "/v1/settings",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_settings.update_settings_endpoint, v1_settings.update_settings_endpoint,
session_manager=services["session_manager"], session_manager=services["session_manager"],
@ -1425,7 +1425,7 @@ async def create_app():
# Knowledge filters endpoints # Knowledge filters endpoints
Route( Route(
"/v1/knowledge-filters", "/v1/knowledge-filters",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_knowledge_filters.create_endpoint, v1_knowledge_filters.create_endpoint,
knowledge_filter_service=services["knowledge_filter_service"], knowledge_filter_service=services["knowledge_filter_service"],
@ -1436,7 +1436,7 @@ async def create_app():
), ),
Route( Route(
"/v1/knowledge-filters/search", "/v1/knowledge-filters/search",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_knowledge_filters.search_endpoint, v1_knowledge_filters.search_endpoint,
knowledge_filter_service=services["knowledge_filter_service"], knowledge_filter_service=services["knowledge_filter_service"],
@ -1447,7 +1447,7 @@ async def create_app():
), ),
Route( Route(
"/v1/knowledge-filters/{filter_id}", "/v1/knowledge-filters/{filter_id}",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_knowledge_filters.get_endpoint, v1_knowledge_filters.get_endpoint,
knowledge_filter_service=services["knowledge_filter_service"], knowledge_filter_service=services["knowledge_filter_service"],
@ -1458,7 +1458,7 @@ async def create_app():
), ),
Route( Route(
"/v1/knowledge-filters/{filter_id}", "/v1/knowledge-filters/{filter_id}",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_knowledge_filters.update_endpoint, v1_knowledge_filters.update_endpoint,
knowledge_filter_service=services["knowledge_filter_service"], knowledge_filter_service=services["knowledge_filter_service"],
@ -1469,7 +1469,7 @@ async def create_app():
), ),
Route( Route(
"/v1/knowledge-filters/{filter_id}", "/v1/knowledge-filters/{filter_id}",
require_api_key(services["api_key_service"])( require_api_key(services["api_key_service"], services["session_manager"])(
partial( partial(
v1_knowledge_filters.delete_endpoint, v1_knowledge_filters.delete_endpoint,
knowledge_filter_service=services["knowledge_filter_service"], knowledge_filter_service=services["knowledge_filter_service"],

View file

@ -158,6 +158,7 @@ class TaskProcessor:
connector_type: str = "local", connector_type: str = "local",
embedding_model: str = None, embedding_model: str = None,
is_sample_data: bool = False, is_sample_data: bool = False,
allowed_groups: list = None,
): ):
""" """
Standard processing pipeline for non-Langflow processors: Standard processing pipeline for non-Langflow processors:
@ -166,6 +167,8 @@ class TaskProcessor:
Args: Args:
embedding_model: Embedding model to use (defaults to the current embedding_model: Embedding model to use (defaults to the current
embedding model from settings) embedding model from settings)
allowed_groups: List of groups that can access this document (RBAC).
Empty list or None means no group restrictions.
""" """
import datetime import datetime
from config.settings import INDEX_NAME, clients, get_embedding_model from config.settings import INDEX_NAME, clients, get_embedding_model
@ -259,6 +262,11 @@ class TaskProcessor:
if owner_email is not None: if owner_email is not None:
chunk_doc["owner_email"] = owner_email chunk_doc["owner_email"] = owner_email
# RBAC: Set allowed groups for access control
# If allowed_groups is provided and non-empty, store it for DLS filtering
if allowed_groups:
chunk_doc["allowed_groups"] = allowed_groups
# Mark as sample data if specified # Mark as sample data if specified
if is_sample_data: if is_sample_data:
chunk_doc["is_sample_data"] = "true" chunk_doc["is_sample_data"] = "true"
@ -309,6 +317,7 @@ class DocumentFileProcessor(TaskProcessor):
owner_name: str = None, owner_name: str = None,
owner_email: str = None, owner_email: str = None,
is_sample_data: bool = False, is_sample_data: bool = False,
allowed_groups: list = None,
): ):
super().__init__(document_service) super().__init__(document_service)
self.owner_user_id = owner_user_id self.owner_user_id = owner_user_id
@ -316,6 +325,7 @@ class DocumentFileProcessor(TaskProcessor):
self.owner_name = owner_name self.owner_name = owner_name
self.owner_email = owner_email self.owner_email = owner_email
self.is_sample_data = is_sample_data self.is_sample_data = is_sample_data
self.allowed_groups = allowed_groups
async def process_item( async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask self, upload_task: UploadTask, item: str, file_task: FileTask
@ -351,6 +361,7 @@ class DocumentFileProcessor(TaskProcessor):
file_size=file_size, file_size=file_size,
connector_type="local", connector_type="local",
is_sample_data=self.is_sample_data, is_sample_data=self.is_sample_data,
allowed_groups=self.allowed_groups,
) )
file_task.status = TaskStatus.COMPLETED file_task.status = TaskStatus.COMPLETED
@ -382,6 +393,7 @@ class ConnectorFileProcessor(TaskProcessor):
owner_name: str = None, owner_name: str = None,
owner_email: str = None, owner_email: str = None,
document_service=None, document_service=None,
allowed_groups: list = None,
): ):
super().__init__(document_service=document_service) super().__init__(document_service=document_service)
self.connector_service = connector_service self.connector_service = connector_service
@ -391,6 +403,7 @@ class ConnectorFileProcessor(TaskProcessor):
self.jwt_token = jwt_token self.jwt_token = jwt_token
self.owner_name = owner_name self.owner_name = owner_name
self.owner_email = owner_email self.owner_email = owner_email
self.allowed_groups = allowed_groups
async def process_item( async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask self, upload_task: UploadTask, item: str, file_task: FileTask
@ -445,6 +458,7 @@ class ConnectorFileProcessor(TaskProcessor):
owner_email=self.owner_email, owner_email=self.owner_email,
file_size=len(document.content), file_size=len(document.content),
connector_type=connection.connector_type, connector_type=connection.connector_type,
allowed_groups=self.allowed_groups,
) )
# Add connector-specific metadata # Add connector-specific metadata
@ -478,6 +492,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
jwt_token: str = None, jwt_token: str = None,
owner_name: str = None, owner_name: str = None,
owner_email: str = None, owner_email: str = None,
allowed_groups: list = None,
): ):
super().__init__() super().__init__()
self.langflow_connector_service = langflow_connector_service self.langflow_connector_service = langflow_connector_service
@ -487,6 +502,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
self.jwt_token = jwt_token self.jwt_token = jwt_token
self.owner_name = owner_name self.owner_name = owner_name
self.owner_email = owner_email self.owner_email = owner_email
self.allowed_groups = allowed_groups
async def process_item( async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask self, upload_task: UploadTask, item: str, file_task: FileTask
@ -580,6 +596,7 @@ class S3FileProcessor(TaskProcessor):
jwt_token: str = None, jwt_token: str = None,
owner_name: str = None, owner_name: str = None,
owner_email: str = None, owner_email: str = None,
allowed_groups: list = None,
): ):
import boto3 import boto3
@ -590,6 +607,7 @@ class S3FileProcessor(TaskProcessor):
self.jwt_token = jwt_token self.jwt_token = jwt_token
self.owner_name = owner_name self.owner_name = owner_name
self.owner_email = owner_email self.owner_email = owner_email
self.allowed_groups = allowed_groups
async def process_item( async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask self, upload_task: UploadTask, item: str, file_task: FileTask
@ -638,6 +656,7 @@ class S3FileProcessor(TaskProcessor):
owner_email=self.owner_email, owner_email=self.owner_email,
file_size=file_size, file_size=file_size,
connector_type="s3", connector_type="s3",
allowed_groups=self.allowed_groups,
) )
result["path"] = f"s3://{self.bucket}/{item}" result["path"] = f"s3://{self.bucket}/{item}"
@ -669,6 +688,7 @@ class LangflowFileProcessor(TaskProcessor):
settings: dict = None, settings: dict = None,
delete_after_ingest: bool = True, delete_after_ingest: bool = True,
replace_duplicates: bool = False, replace_duplicates: bool = False,
allowed_groups: list = None,
): ):
super().__init__() super().__init__()
self.langflow_file_service = langflow_file_service self.langflow_file_service = langflow_file_service
@ -682,6 +702,7 @@ class LangflowFileProcessor(TaskProcessor):
self.settings = settings self.settings = settings
self.delete_after_ingest = delete_after_ingest self.delete_after_ingest = delete_after_ingest
self.replace_duplicates = replace_duplicates self.replace_duplicates = replace_duplicates
self.allowed_groups = allowed_groups
async def process_item( async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask self, upload_task: UploadTask, item: str, file_task: FileTask
@ -765,6 +786,10 @@ class LangflowFileProcessor(TaskProcessor):
metadata_tweaks.append({"key": "owner_email", "value": self.owner_email}) metadata_tweaks.append({"key": "owner_email", "value": self.owner_email})
# Mark as local upload for connector_type # Mark as local upload for connector_type
metadata_tweaks.append({"key": "connector_type", "value": "local"}) metadata_tweaks.append({"key": "connector_type", "value": "local"})
# RBAC: Add allowed_groups for access control
if self.allowed_groups:
# Store as comma-separated string for Langflow metadata
metadata_tweaks.append({"key": "allowed_groups", "value": ",".join(self.allowed_groups)})
if metadata_tweaks: if metadata_tweaks:
# Initialize the OpenSearch component tweaks if not already present # Initialize the OpenSearch component tweaks if not already present

View file

@ -52,15 +52,19 @@ class APIKeyService:
user_email: str, user_email: str,
name: str, name: str,
jwt_token: str = None, jwt_token: str = None,
roles: List[str] = None,
groups: List[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Create a new API key for a user. Create a new API key for a user with optional RBAC restrictions.
Args: Args:
user_id: The user's ID user_id: The user's ID
user_email: The user's email user_email: The user's email
name: A friendly name for the key name: A friendly name for the key
jwt_token: JWT token for OpenSearch authentication jwt_token: JWT token for OpenSearch authentication
roles: Optional list of roles to restrict this key to
groups: Optional list of groups this key can access
Returns: Returns:
Dict with success status, key info, and the full key (only shown once) Dict with success status, key info, and the full key (only shown once)
@ -74,6 +78,14 @@ class APIKeyService:
now = datetime.utcnow().isoformat() now = datetime.utcnow().isoformat()
# Default roles if not specified
if roles is None:
roles = ["openrag_user"]
# Default groups to empty if not specified
if groups is None:
groups = []
# Create the document to store # Create the document to store
key_doc = { key_doc = {
"key_id": key_id, "key_id": key_id,
@ -85,6 +97,9 @@ class APIKeyService:
"created_at": now, "created_at": now,
"last_used_at": None, "last_used_at": None,
"revoked": False, "revoked": False,
# RBAC fields
"roles": roles,
"groups": groups,
} }
# Get OpenSearch client # Get OpenSearch client
@ -105,6 +120,8 @@ class APIKeyService:
user_id=user_id, user_id=user_id,
key_id=key_id, key_id=key_id,
key_prefix=key_prefix, key_prefix=key_prefix,
roles=roles,
groups=groups,
) )
return { return {
"success": True, "success": True,
@ -112,6 +129,8 @@ class APIKeyService:
"key_prefix": key_prefix, "key_prefix": key_prefix,
"name": name, "name": name,
"created_at": now, "created_at": now,
"roles": roles,
"groups": groups,
"api_key": full_key, # Only returned once! "api_key": full_key, # Only returned once!
} }
else: else:
@ -123,13 +142,13 @@ class APIKeyService:
async def validate_key(self, api_key: str) -> Optional[Dict[str, Any]]: async def validate_key(self, api_key: str) -> Optional[Dict[str, Any]]:
""" """
Validate an API key and return user info if valid. Validate an API key and return user info with RBAC claims if valid.
Args: Args:
api_key: The full API key to validate api_key: The full API key to validate
Returns: Returns:
Dict with user info if valid, None if invalid Dict with user info including roles and groups if valid, None if invalid
""" """
try: try:
# Check key format # Check key format
@ -181,11 +200,15 @@ class APIKeyService:
except Exception: except Exception:
pass # Don't fail validation if update fails pass # Don't fail validation if update fails
# Return user info with RBAC claims
return { return {
"key_id": key_doc["key_id"], "key_id": key_doc["key_id"],
"user_id": key_doc["user_id"], "user_id": key_doc["user_id"],
"user_email": key_doc["user_email"], "user_email": key_doc["user_email"],
"name": key_doc["name"], "name": key_doc["name"],
# RBAC fields - provide defaults for backward compatibility
"roles": key_doc.get("roles", ["openrag_user"]),
"groups": key_doc.get("groups", []),
} }
except Exception as e: except Exception as e:
@ -225,6 +248,8 @@ class APIKeyService:
"created_at", "created_at",
"last_used_at", "last_used_at",
"revoked", "revoked",
"roles",
"groups",
], ],
"size": 100, "size": 100,
} }

View file

@ -16,6 +16,8 @@ class ChatService:
previous_response_id: str = None, previous_response_id: str = None,
stream: bool = False, stream: bool = False,
filter_id: str = None, filter_id: str = None,
groups: list = None,
roles: list = None,
): ):
"""Handle chat requests using the patched OpenAI client""" """Handle chat requests using the patched OpenAI client"""
if not prompt: if not prompt:
@ -23,7 +25,7 @@ class ChatService:
# Set authentication context for this request so tools can access it # Set authentication context for this request so tools can access it
if user_id and jwt_token: if user_id and jwt_token:
set_auth_context(user_id, jwt_token) set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
if stream: if stream:
return async_chat_stream( return async_chat_stream(
@ -54,6 +56,8 @@ class ChatService:
previous_response_id: str = None, previous_response_id: str = None,
stream: bool = False, stream: bool = False,
filter_id: str = None, filter_id: str = None,
groups: list = None,
roles: list = None,
): ):
"""Handle Langflow chat requests""" """Handle Langflow chat requests"""
if not prompt: if not prompt:
@ -346,9 +350,9 @@ class ChatService:
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
) )
else: # chat else: # chat
# Set auth context for chat tools and provide user_id # Set auth context for chat tools and provide user_id with RBAC
if user_id and jwt_token: if user_id and jwt_token:
set_auth_context(user_id, jwt_token) set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
response_text, response_id = await async_chat( response_text, response_id = await async_chat(
clients.patched_llm_client, clients.patched_llm_client,
document_prompt, document_prompt,

View file

@ -2,7 +2,7 @@ import copy
from typing import Any, Dict from typing import Any, Dict
from agentd.tool_decorator import tool from agentd.tool_decorator import tool
from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model, WATSONX_EMBEDDING_DIMENSIONS from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model, WATSONX_EMBEDDING_DIMENSIONS
from auth_context import get_auth_context from auth_context import get_auth_context, get_current_user_groups
from utils.logging_config import get_logger from utils.logging_config import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@ -259,8 +259,24 @@ class SearchService:
# Build query body # Build query body
if is_wildcard_match_all: if is_wildcard_match_all:
# Match all documents; still allow filters to narrow scope # Match all documents; still allow filters to narrow scope
if filter_clauses: # Also add RBAC group filter for wildcard queries
query_block = {"bool": {"filter": filter_clauses}} wildcard_filters = list(filter_clauses) # Copy existing filters
user_groups = get_current_user_groups()
if user_groups:
groups_access_filter = {
"bool": {
"should": [
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
{"terms": {"allowed_groups": user_groups}},
],
"minimum_should_match": 1
}
}
wildcard_filters.append(groups_access_filter)
if wildcard_filters:
query_block = {"bool": {"filter": wildcard_filters}}
else: else:
query_block = {"match_all": {}} query_block = {"match_all": {}}
else: else:
@ -292,6 +308,29 @@ class SearchService:
# Add exists filter to existing filters # Add exists filter to existing filters
all_filters = [*filter_clauses, exists_any_embedding] all_filters = [*filter_clauses, exists_any_embedding]
# RBAC: Add group-based access control filter (fallback if DLS isn't configured)
# Documents are accessible if:
# 1. No allowed_groups field exists (backward compatibility, open access)
# 2. User's groups match any of the document's allowed_groups
user_groups = get_current_user_groups()
if user_groups:
groups_access_filter = {
"bool": {
"should": [
# Document has no allowed_groups restriction
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
# User's groups match document's allowed_groups
{"terms": {"allowed_groups": user_groups}},
],
"minimum_should_match": 1
}
}
all_filters.append(groups_access_filter)
logger.debug(
"Added RBAC group filter",
user_groups=user_groups,
)
logger.debug( logger.debug(
"Building hybrid query with filters", "Building hybrid query with filters",
user_filters_count=len(filter_clauses), user_filters_count=len(filter_clauses),

View file

@ -2,8 +2,8 @@ import json
import jwt import jwt
import httpx import httpx
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, Optional, Any from typing import Dict, Optional, Any, List
from dataclasses import dataclass, asdict from dataclasses import dataclass, field
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
import os import os
from utils.logging_config import get_logger from utils.logging_config import get_logger
@ -24,12 +24,19 @@ class User:
provider: str = "google" provider: str = "google"
created_at: datetime = None created_at: datetime = None
last_login: datetime = None last_login: datetime = None
# RBAC fields
roles: List[str] = field(default_factory=lambda: ["openrag_user"])
groups: List[str] = field(default_factory=list)
def __post_init__(self): def __post_init__(self):
if self.created_at is None: if self.created_at is None:
self.created_at = datetime.now() self.created_at = datetime.now()
if self.last_login is None: if self.last_login is None:
self.last_login = datetime.now() self.last_login = datetime.now()
if self.roles is None:
self.roles = ["openrag_user"]
if self.groups is None:
self.groups = []
class AnonymousUser(User): class AnonymousUser(User):
"""Anonymous user""" """Anonymous user"""
@ -136,11 +143,31 @@ class SessionManager:
# Create JWT token using the shared method # Create JWT token using the shared method
return self.create_jwt_token(user) return self.create_jwt_token(user)
def create_jwt_token(self, user: User) -> str: def create_jwt_token(
"""Create JWT token for an existing user""" self,
user: User,
roles: Optional[List[str]] = None,
groups: Optional[List[str]] = None,
expiration_days: int = 7,
) -> str:
"""Create JWT token for an existing user.
Args:
user: The User object to create a token for
roles: Optional roles override (for restricted API key tokens)
groups: Optional groups override (for restricted API key tokens)
expiration_days: Token expiration in days (default 7)
Returns:
Encoded JWT token string
"""
# Use OpenSearch-compatible issuer for OIDC validation # Use OpenSearch-compatible issuer for OIDC validation
oidc_issuer = "http://openrag-backend:8000" oidc_issuer = "http://openrag-backend:8000"
# Use provided roles/groups or fall back to user's defaults
effective_roles = roles if roles is not None else user.roles
effective_groups = groups if groups is not None else user.groups
# Create JWT token with OIDC-compliant claims # Create JWT token with OIDC-compliant claims
now = datetime.utcnow() now = datetime.utcnow()
token_payload = { token_payload = {
@ -148,7 +175,7 @@ class SessionManager:
"iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC "iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC
"sub": user.user_id, # Subject (user ID) "sub": user.user_id, # Subject (user ID)
"aud": ["opensearch", "openrag"], # Audience "aud": ["opensearch", "openrag"], # Audience
"exp": now + timedelta(days=7), # Expiration "exp": now + timedelta(days=expiration_days), # Expiration
"iat": now, # Issued at "iat": now, # Issued at
"auth_time": int(now.timestamp()), # Authentication time "auth_time": int(now.timestamp()), # Authentication time
# Custom claims # Custom claims
@ -157,7 +184,9 @@ class SessionManager:
"name": user.name, "name": user.name,
"preferred_username": user.email, "preferred_username": user.email,
"email_verified": True, "email_verified": True,
"roles": ["openrag_user"], # Backend role for OpenSearch # RBAC claims
"roles": effective_roles, # Backend roles for OpenSearch/tools
"groups": effective_groups, # Group-based access control
} }
token = jwt.encode(token_payload, self.private_key, algorithm="RS256") token = jwt.encode(token_payload, self.private_key, algorithm="RS256")