RBAC basic implementation

This commit is contained in:
Edwin Jose 2025-12-26 16:41:39 -05:00
parent c923ecb396
commit ce51628db2
14 changed files with 409 additions and 98 deletions

View file

@ -6,6 +6,8 @@ import {
/**
 * Request body sent when creating a new API key.
 */
export interface CreateApiKeyRequest {
/** Display name used to identify the key later. */
name: string;
/** Optional roles to attach to the key; omit to let the server apply its default. */
roles?: string[];
/** Optional groups the key is restricted to; omit for no group restriction. */
groups?: string[];
}
export interface CreateApiKeyResponse {
@ -14,6 +16,8 @@ export interface CreateApiKeyResponse {
name: string;
key_prefix: string;
created_at: string;
roles?: string[];
groups?: string[];
}
export const useCreateApiKeyMutation = (

View file

@ -6,6 +6,8 @@ export interface ApiKey {
key_prefix: string;
created_at: string;
last_used_at: string | null;
roles?: string[];
groups?: string[];
}
export interface GetApiKeysResponse {

View file

@ -136,6 +136,7 @@ function KnowledgeSourcesPage() {
// API Keys state
const [createKeyDialogOpen, setCreateKeyDialogOpen] = useState(false);
const [newKeyName, setNewKeyName] = useState("");
const [newKeyGroups, setNewKeyGroups] = useState("");
const [newlyCreatedKey, setNewlyCreatedKey] = useState<string | null>(null);
const [showKeyDialogOpen, setShowKeyDialogOpen] = useState(false);
@ -156,6 +157,7 @@ function KnowledgeSourcesPage() {
setCreateKeyDialogOpen(false);
setShowKeyDialogOpen(true);
setNewKeyName("");
setNewKeyGroups("");
toast.success("API key created");
},
onError: (error) => {
@ -438,7 +440,16 @@ function KnowledgeSourcesPage() {
toast.error("Please enter a name for the API key");
return;
}
createApiKeyMutation.mutate({ name: newKeyName.trim() });
// Parse groups from comma-separated string
const groups = newKeyGroups
.split(",")
.map((g) => g.trim())
.filter((g) => g.length > 0);
createApiKeyMutation.mutate({
name: newKeyName.trim(),
groups: groups.length > 0 ? groups : undefined,
});
};
const handleRevokeApiKey = (keyId: string) => {
@ -1426,6 +1437,9 @@ function KnowledgeSourcesPage() {
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
Key
</th>
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
Groups
</th>
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
Created
</th>
@ -1448,6 +1462,24 @@ function KnowledgeSourcesPage() {
{key.key_prefix}...
</code>
</td>
<td className="px-4 py-3 text-sm text-muted-foreground">
{key.groups && key.groups.length > 0 ? (
<div className="flex flex-wrap gap-1">
{key.groups.map((group: string) => (
<span
key={group}
className="inline-flex items-center px-2 py-0.5 rounded-full text-xs bg-primary/10 text-primary"
>
{group}
</span>
))}
</div>
) : (
<span className="text-muted-foreground/50">
All groups
</span>
)}
</td>
<td className="px-4 py-3 text-sm text-muted-foreground">
{formatDate(key.created_at)}
</td>
@ -1507,58 +1539,77 @@ function KnowledgeSourcesPage() {
</Card>
)}
{/* Create API Key Dialog */}
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>Create API Key</DialogTitle>
<DialogDescription>
Give your API key a name to help you identify it later.
</DialogDescription>
</DialogHeader>
<div className="py-4">
<LabelWrapper label="Name" id="api-key-name">
<Input
id="api-key-name"
placeholder="e.g., Production App, Development"
value={newKeyName}
onChange={(e) => setNewKeyName(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
handleCreateApiKey();
}
}}
/>
</LabelWrapper>
</div>
<DialogFooter>
<Button
variant="ghost"
onClick={() => {
setCreateKeyDialogOpen(false);
setNewKeyName("");
{/* Create API Key Dialog */}
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>Create API Key</DialogTitle>
<DialogDescription>
Create an API key with optional group restrictions for access
control.
</DialogDescription>
</DialogHeader>
<div className="py-4 space-y-4">
<LabelWrapper label="Name" id="api-key-name">
<Input
id="api-key-name"
placeholder="e.g., Production App, Development"
value={newKeyName}
onChange={(e) => setNewKeyName(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
handleCreateApiKey();
}
}}
size="sm"
>
Cancel
</Button>
<Button
onClick={handleCreateApiKey}
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
size="sm"
>
{createApiKeyMutation.isPending ? (
<>
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
Creating...
</>
) : (
"Create Key"
)}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
/>
</LabelWrapper>
<LabelWrapper
label="Groups (optional)"
id="api-key-groups"
helperText="Comma-separated list of groups this key can access"
>
<Input
id="api-key-groups"
placeholder="e.g., finance, hr, engineering"
value={newKeyGroups}
onChange={(e) => setNewKeyGroups(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
handleCreateApiKey();
}
}}
/>
</LabelWrapper>
</div>
<DialogFooter>
<Button
variant="ghost"
onClick={() => {
setCreateKeyDialogOpen(false);
setNewKeyName("");
setNewKeyGroups("");
}}
size="sm"
>
Cancel
</Button>
<Button
onClick={handleCreateApiKey}
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
size="sm"
>
{createApiKeyMutation.isPending ? (
<>
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
Creating...
</>
) : (
"Create Key"
)}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
{/* Show Created API Key Dialog */}
<Dialog

View file

@ -35,6 +35,10 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
set_search_limit(limit)
set_score_threshold(score_threshold)
# Get RBAC groups and roles from user for access control
user_groups = getattr(user, "groups", [])
user_roles = getattr(user, "roles", [])
if stream:
return StreamingResponse(
await chat_service.chat(
@ -44,6 +48,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id,
stream=True,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
),
media_type="text/event-stream",
headers={
@ -61,6 +67,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id,
stream=False,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
)
return JSONResponse(result)
@ -81,6 +89,10 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
# Get RBAC groups and roles from user for access control
user_groups = getattr(user, "groups", [])
user_roles = getattr(user, "roles", [])
if not prompt:
return JSONResponse({"error": "Prompt is required"}, status_code=400)
@ -105,6 +117,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id,
stream=True,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
),
media_type="text/event-stream",
headers={
@ -122,6 +136,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id,
stream=False,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
)
return JSONResponse(result)

View file

@ -42,10 +42,14 @@ async def list_keys_endpoint(request: Request, api_key_service):
async def create_key_endpoint(request: Request, api_key_service):
"""
Create a new API key for the authenticated user.
Create a new API key for the authenticated user with optional RBAC restrictions.
POST /keys
Body: {"name": "My API Key"}
Body: {
"name": "My API Key",
"roles": ["openrag_user"], // Optional: restrict key to specific roles
"groups": ["finance", "hr"] // Optional: restrict key to specific groups
}
Response:
{
@ -53,6 +57,8 @@ async def create_key_endpoint(request: Request, api_key_service):
"key_id": "...",
"key_prefix": "orag_abc12345",
"name": "My API Key",
"roles": ["openrag_user"],
"groups": ["finance", "hr"],
"created_at": "2024-01-01T00:00:00",
"api_key": "orag_abc12345..." // Full key, only shown once!
}
@ -78,11 +84,29 @@ async def create_key_endpoint(request: Request, api_key_service):
status_code=400,
)
# Extract optional RBAC fields
roles = data.get("roles")
groups = data.get("groups")
# Validate roles and groups are lists if provided
if roles is not None and not isinstance(roles, list):
return JSONResponse(
{"success": False, "error": "roles must be a list"},
status_code=400,
)
if groups is not None and not isinstance(groups, list):
return JSONResponse(
{"success": False, "error": "groups must be a list"},
status_code=400,
)
result = await api_key_service.create_key(
user_id=user_id,
user_email=user_email,
name=name,
jwt_token=jwt_token,
roles=roles,
groups=groups,
)
if result.get("success"):

View file

@ -104,14 +104,18 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
user = request.state.user
user_id = user.user_id
jwt_token = session_manager.get_effective_jwt_token(user_id, None)
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
# Get RBAC groups and roles from user for access control
user_groups = getattr(user, "groups", [])
user_roles = getattr(user, "roles", [])
# Set context variables for search tool
if filters:
set_search_filters(filters)
set_search_limit(limit)
set_score_threshold(score_threshold)
set_auth_context(user_id, jwt_token)
set_auth_context(user_id, jwt_token, groups=user_groups, roles=user_roles)
if stream:
raw_stream = await chat_service.langflow_chat(
@ -121,6 +125,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
previous_response_id=chat_id,
stream=True,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
)
chat_id_container = {}
return StreamingResponse(
@ -136,6 +142,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
previous_response_id=chat_id,
stream=False,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
)
# Transform response_id to chat_id for v1 API format
return JSONResponse({

View file

@ -1,5 +1,8 @@
"""
API Key middleware for authenticating public API requests.
This middleware validates API keys and generates ephemeral JWTs with the
key's specific roles and groups for downstream security enforcement.
"""
from starlette.requests import Request
from starlette.responses import JSONResponse
@ -32,14 +35,18 @@ def _extract_api_key(request: Request) -> str | None:
return None
def require_api_key(api_key_service):
def require_api_key(api_key_service, session_manager=None):
"""
Decorator to require API key authentication for public API endpoints.
Generates an ephemeral JWT with the API key's specific roles and groups
to enforce RBAC in downstream services (OpenSearch, tools, etc.).
Usage:
@require_api_key(api_key_service)
@require_api_key(api_key_service, session_manager)
async def my_endpoint(request):
user = request.state.user
jwt_token = request.state.jwt_token # Ephemeral restricted JWT
...
"""
@ -57,7 +64,7 @@ def require_api_key(api_key_service):
status_code=401,
)
# Validate the key
# Validate the key and get RBAC claims
user_info = await api_key_service.validate_key(api_key)
if not user_info:
@ -69,19 +76,46 @@ def require_api_key(api_key_service):
status_code=401,
)
# Create a User object from the API key info
# Extract RBAC fields from API key
key_roles = user_info.get("roles", ["openrag_user"])
key_groups = user_info.get("groups", [])
# Create a User object with the API key's roles and groups
user = User(
user_id=user_info["user_id"],
email=user_info["user_email"],
name=user_info.get("name", "API User"),
picture=None,
provider="api_key",
roles=key_roles,
groups=key_groups,
)
# Set request state
request.state.user = user
request.state.api_key_id = user_info["key_id"]
request.state.jwt_token = None # No JWT for API key auth
# Generate ephemeral JWT with the API key's restricted roles/groups
if session_manager:
# Create a short-lived JWT with the key's specific permissions
ephemeral_jwt = session_manager.create_jwt_token(
user=user,
roles=key_roles,
groups=key_groups,
expiration_days=1, # Short-lived for API requests
)
request.state.jwt_token = ephemeral_jwt
logger.debug(
"Generated ephemeral JWT for API key",
key_id=user_info["key_id"],
roles=key_roles,
groups=key_groups,
)
else:
request.state.jwt_token = None
logger.warning(
"No session_manager provided - JWT not generated for API key"
)
return await handler(request)
@ -90,10 +124,13 @@ def require_api_key(api_key_service):
return decorator
def optional_api_key(api_key_service):
def optional_api_key(api_key_service, session_manager=None):
"""
Decorator to optionally authenticate with API key.
Sets request.state.user to None if no valid API key is provided.
When a valid API key is provided, generates an ephemeral JWT with
the key's specific roles and groups.
"""
def decorator(handler):
@ -102,21 +139,38 @@ def optional_api_key(api_key_service):
api_key = _extract_api_key(request)
if api_key:
# Validate the key
# Validate the key and get RBAC claims
user_info = await api_key_service.validate_key(api_key)
if user_info:
# Create a User object from the API key info
# Extract RBAC fields from API key
key_roles = user_info.get("roles", ["openrag_user"])
key_groups = user_info.get("groups", [])
# Create a User object with the API key's roles and groups
user = User(
user_id=user_info["user_id"],
email=user_info["user_email"],
name=user_info.get("name", "API User"),
picture=None,
provider="api_key",
roles=key_roles,
groups=key_groups,
)
request.state.user = user
request.state.api_key_id = user_info["key_id"]
request.state.jwt_token = None
# Generate ephemeral JWT with the API key's restricted roles/groups
if session_manager:
ephemeral_jwt = session_manager.create_jwt_token(
user=user,
roles=key_roles,
groups=key_groups,
expiration_days=1,
)
request.state.jwt_token = ephemeral_jwt
else:
request.state.jwt_token = None
else:
request.state.user = None
request.state.api_key_id = None

View file

@ -4,7 +4,7 @@ Uses contextvars to safely pass user auth info through async calls.
"""
from contextvars import ContextVar
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, List
# Context variables for current request authentication
_current_user_id: ContextVar[Optional[str]] = ContextVar(
@ -13,6 +13,12 @@ _current_user_id: ContextVar[Optional[str]] = ContextVar(
_current_jwt_token: ContextVar[Optional[str]] = ContextVar(
"current_jwt_token", default=None
)
_current_user_groups: ContextVar[Optional[List[str]]] = ContextVar(
"current_user_groups", default=None
)
_current_user_roles: ContextVar[Optional[List[str]]] = ContextVar(
"current_user_roles", default=None
)
_current_search_filters: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
"current_search_filters", default=None
)
@ -24,10 +30,24 @@ _current_score_threshold: ContextVar[Optional[float]] = ContextVar(
)
def set_auth_context(user_id: str, jwt_token: str):
"""Set authentication context for the current async context"""
def set_auth_context(
    user_id: str,
    jwt_token: str,
    groups: Optional[List[str]] = None,
    roles: Optional[List[str]] = None,
):
    """Set authentication context for the current async context.

    Args:
        user_id: The user's ID.
        jwt_token: The JWT token for authentication.
        groups: Optional list of groups the user belongs to (for RBAC).
            ``None`` or an empty list is stored as an empty list.
        roles: Optional list of roles the user has (for RBAC).
            ``None`` or an empty list is stored as an empty list.
    """
    _current_user_id.set(user_id)
    _current_jwt_token.set(jwt_token)
    # Store copies rather than the caller's own list objects: the lists
    # passed in are typically the ones hanging off the request's User, and
    # a later in-place mutation there must not silently alter the RBAC
    # context recorded for this request.
    _current_user_groups.set(list(groups) if groups else [])
    _current_user_roles.set(list(roles) if roles else [])
def get_current_user_id() -> Optional[str]:
@ -40,6 +60,16 @@ def get_current_jwt_token() -> Optional[str]:
return _current_jwt_token.get()
def get_current_user_groups() -> List[str]:
    """Return the RBAC groups stored in the current context.

    Falls back to an empty list when nothing (or an empty value) is stored.
    """
    stored_groups = _current_user_groups.get()
    if stored_groups:
        return stored_groups
    return []
def get_current_user_roles() -> List[str]:
    """Return the RBAC roles stored in the current context.

    Falls back to an empty list when nothing (or an empty value) is stored.
    """
    stored_roles = _current_user_roles.get()
    if stored_roles:
        return stored_roles
    return []
def get_auth_context() -> tuple[Optional[str], Optional[str]]:
    """Return the ``(user_id, jwt_token)`` pair held in the current context.

    Either element is whatever the corresponding context variable currently
    holds, which may be ``None``.
    """
    user_id = _current_user_id.get()
    jwt_token = _current_jwt_token.get()
    return user_id, jwt_token

View file

@ -1314,7 +1314,7 @@ async def create_app():
# Chat endpoints
Route(
"/v1/chat",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_chat.chat_create_endpoint,
chat_service=services["chat_service"],
@ -1325,7 +1325,7 @@ async def create_app():
),
Route(
"/v1/chat",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_chat.chat_list_endpoint,
chat_service=services["chat_service"],
@ -1336,7 +1336,7 @@ async def create_app():
),
Route(
"/v1/chat/{chat_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_chat.chat_get_endpoint,
chat_service=services["chat_service"],
@ -1347,7 +1347,7 @@ async def create_app():
),
Route(
"/v1/chat/{chat_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_chat.chat_delete_endpoint,
chat_service=services["chat_service"],
@ -1359,7 +1359,7 @@ async def create_app():
# Search endpoint
Route(
"/v1/search",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_search.search_endpoint,
search_service=services["search_service"],
@ -1371,7 +1371,7 @@ async def create_app():
# Documents endpoints
Route(
"/v1/documents/ingest",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_documents.ingest_endpoint,
document_service=services["document_service"],
@ -1384,7 +1384,7 @@ async def create_app():
),
Route(
"/v1/tasks/{task_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_documents.task_status_endpoint,
task_service=services["task_service"],
@ -1395,7 +1395,7 @@ async def create_app():
),
Route(
"/v1/documents",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_documents.delete_document_endpoint,
document_service=services["document_service"],
@ -1407,14 +1407,14 @@ async def create_app():
# Settings endpoints
Route(
"/v1/settings",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(v1_settings.get_settings_endpoint)
),
methods=["GET"],
),
Route(
"/v1/settings",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_settings.update_settings_endpoint,
session_manager=services["session_manager"],
@ -1425,7 +1425,7 @@ async def create_app():
# Knowledge filters endpoints
Route(
"/v1/knowledge-filters",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_knowledge_filters.create_endpoint,
knowledge_filter_service=services["knowledge_filter_service"],
@ -1436,7 +1436,7 @@ async def create_app():
),
Route(
"/v1/knowledge-filters/search",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_knowledge_filters.search_endpoint,
knowledge_filter_service=services["knowledge_filter_service"],
@ -1447,7 +1447,7 @@ async def create_app():
),
Route(
"/v1/knowledge-filters/{filter_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_knowledge_filters.get_endpoint,
knowledge_filter_service=services["knowledge_filter_service"],
@ -1458,7 +1458,7 @@ async def create_app():
),
Route(
"/v1/knowledge-filters/{filter_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_knowledge_filters.update_endpoint,
knowledge_filter_service=services["knowledge_filter_service"],
@ -1469,7 +1469,7 @@ async def create_app():
),
Route(
"/v1/knowledge-filters/{filter_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_knowledge_filters.delete_endpoint,
knowledge_filter_service=services["knowledge_filter_service"],

View file

@ -158,6 +158,7 @@ class TaskProcessor:
connector_type: str = "local",
embedding_model: str = None,
is_sample_data: bool = False,
allowed_groups: list = None,
):
"""
Standard processing pipeline for non-Langflow processors:
@ -166,6 +167,8 @@ class TaskProcessor:
Args:
embedding_model: Embedding model to use (defaults to the current
embedding model from settings)
allowed_groups: List of groups that can access this document (RBAC).
Empty list or None means no group restrictions.
"""
import datetime
from config.settings import INDEX_NAME, clients, get_embedding_model
@ -259,6 +262,11 @@ class TaskProcessor:
if owner_email is not None:
chunk_doc["owner_email"] = owner_email
# RBAC: Set allowed groups for access control
# If allowed_groups is provided and non-empty, store it for DLS filtering
if allowed_groups:
chunk_doc["allowed_groups"] = allowed_groups
# Mark as sample data if specified
if is_sample_data:
chunk_doc["is_sample_data"] = "true"
@ -309,6 +317,7 @@ class DocumentFileProcessor(TaskProcessor):
owner_name: str = None,
owner_email: str = None,
is_sample_data: bool = False,
allowed_groups: list = None,
):
super().__init__(document_service)
self.owner_user_id = owner_user_id
@ -316,6 +325,7 @@ class DocumentFileProcessor(TaskProcessor):
self.owner_name = owner_name
self.owner_email = owner_email
self.is_sample_data = is_sample_data
self.allowed_groups = allowed_groups
async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask
@ -351,6 +361,7 @@ class DocumentFileProcessor(TaskProcessor):
file_size=file_size,
connector_type="local",
is_sample_data=self.is_sample_data,
allowed_groups=self.allowed_groups,
)
file_task.status = TaskStatus.COMPLETED
@ -382,6 +393,7 @@ class ConnectorFileProcessor(TaskProcessor):
owner_name: str = None,
owner_email: str = None,
document_service=None,
allowed_groups: list = None,
):
super().__init__(document_service=document_service)
self.connector_service = connector_service
@ -391,6 +403,7 @@ class ConnectorFileProcessor(TaskProcessor):
self.jwt_token = jwt_token
self.owner_name = owner_name
self.owner_email = owner_email
self.allowed_groups = allowed_groups
async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask
@ -445,6 +458,7 @@ class ConnectorFileProcessor(TaskProcessor):
owner_email=self.owner_email,
file_size=len(document.content),
connector_type=connection.connector_type,
allowed_groups=self.allowed_groups,
)
# Add connector-specific metadata
@ -478,6 +492,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
jwt_token: str = None,
owner_name: str = None,
owner_email: str = None,
allowed_groups: list = None,
):
super().__init__()
self.langflow_connector_service = langflow_connector_service
@ -487,6 +502,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
self.jwt_token = jwt_token
self.owner_name = owner_name
self.owner_email = owner_email
self.allowed_groups = allowed_groups
async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask
@ -580,6 +596,7 @@ class S3FileProcessor(TaskProcessor):
jwt_token: str = None,
owner_name: str = None,
owner_email: str = None,
allowed_groups: list = None,
):
import boto3
@ -590,6 +607,7 @@ class S3FileProcessor(TaskProcessor):
self.jwt_token = jwt_token
self.owner_name = owner_name
self.owner_email = owner_email
self.allowed_groups = allowed_groups
async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask
@ -638,6 +656,7 @@ class S3FileProcessor(TaskProcessor):
owner_email=self.owner_email,
file_size=file_size,
connector_type="s3",
allowed_groups=self.allowed_groups,
)
result["path"] = f"s3://{self.bucket}/{item}"
@ -669,6 +688,7 @@ class LangflowFileProcessor(TaskProcessor):
settings: dict = None,
delete_after_ingest: bool = True,
replace_duplicates: bool = False,
allowed_groups: list = None,
):
super().__init__()
self.langflow_file_service = langflow_file_service
@ -682,6 +702,7 @@ class LangflowFileProcessor(TaskProcessor):
self.settings = settings
self.delete_after_ingest = delete_after_ingest
self.replace_duplicates = replace_duplicates
self.allowed_groups = allowed_groups
async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask
@ -765,6 +786,10 @@ class LangflowFileProcessor(TaskProcessor):
metadata_tweaks.append({"key": "owner_email", "value": self.owner_email})
# Mark as local upload for connector_type
metadata_tweaks.append({"key": "connector_type", "value": "local"})
# RBAC: Add allowed_groups for access control
if self.allowed_groups:
# Store as comma-separated string for Langflow metadata
metadata_tweaks.append({"key": "allowed_groups", "value": ",".join(self.allowed_groups)})
if metadata_tweaks:
# Initialize the OpenSearch component tweaks if not already present

View file

@ -52,15 +52,19 @@ class APIKeyService:
user_email: str,
name: str,
jwt_token: str = None,
roles: List[str] = None,
groups: List[str] = None,
) -> Dict[str, Any]:
"""
Create a new API key for a user.
Create a new API key for a user with optional RBAC restrictions.
Args:
user_id: The user's ID
user_email: The user's email
name: A friendly name for the key
jwt_token: JWT token for OpenSearch authentication
roles: Optional list of roles to restrict this key to
groups: Optional list of groups this key can access
Returns:
Dict with success status, key info, and the full key (only shown once)
@ -74,6 +78,14 @@ class APIKeyService:
now = datetime.utcnow().isoformat()
# Default roles if not specified
if roles is None:
roles = ["openrag_user"]
# Default groups to empty if not specified
if groups is None:
groups = []
# Create the document to store
key_doc = {
"key_id": key_id,
@ -85,6 +97,9 @@ class APIKeyService:
"created_at": now,
"last_used_at": None,
"revoked": False,
# RBAC fields
"roles": roles,
"groups": groups,
}
# Get OpenSearch client
@ -105,6 +120,8 @@ class APIKeyService:
user_id=user_id,
key_id=key_id,
key_prefix=key_prefix,
roles=roles,
groups=groups,
)
return {
"success": True,
@ -112,6 +129,8 @@ class APIKeyService:
"key_prefix": key_prefix,
"name": name,
"created_at": now,
"roles": roles,
"groups": groups,
"api_key": full_key, # Only returned once!
}
else:
@ -123,13 +142,13 @@ class APIKeyService:
async def validate_key(self, api_key: str) -> Optional[Dict[str, Any]]:
"""
Validate an API key and return user info if valid.
Validate an API key and return user info with RBAC claims if valid.
Args:
api_key: The full API key to validate
Returns:
Dict with user info if valid, None if invalid
Dict with user info including roles and groups if valid, None if invalid
"""
try:
# Check key format
@ -181,11 +200,15 @@ class APIKeyService:
except Exception:
pass # Don't fail validation if update fails
# Return user info with RBAC claims
return {
"key_id": key_doc["key_id"],
"user_id": key_doc["user_id"],
"user_email": key_doc["user_email"],
"name": key_doc["name"],
# RBAC fields - provide defaults for backward compatibility
"roles": key_doc.get("roles", ["openrag_user"]),
"groups": key_doc.get("groups", []),
}
except Exception as e:
@ -225,6 +248,8 @@ class APIKeyService:
"created_at",
"last_used_at",
"revoked",
"roles",
"groups",
],
"size": 100,
}

View file

@ -16,6 +16,8 @@ class ChatService:
previous_response_id: str = None,
stream: bool = False,
filter_id: str = None,
groups: list = None,
roles: list = None,
):
"""Handle chat requests using the patched OpenAI client"""
if not prompt:
@ -23,7 +25,7 @@ class ChatService:
# Set authentication context for this request so tools can access it
if user_id and jwt_token:
set_auth_context(user_id, jwt_token)
set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
if stream:
return async_chat_stream(
@ -54,6 +56,8 @@ class ChatService:
previous_response_id: str = None,
stream: bool = False,
filter_id: str = None,
groups: list = None,
roles: list = None,
):
"""Handle Langflow chat requests"""
if not prompt:
@ -346,9 +350,9 @@ class ChatService:
previous_response_id=previous_response_id,
)
else: # chat
# Set auth context for chat tools and provide user_id
# Set auth context for chat tools and provide user_id with RBAC
if user_id and jwt_token:
set_auth_context(user_id, jwt_token)
set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
response_text, response_id = await async_chat(
clients.patched_llm_client,
document_prompt,

View file

@ -2,7 +2,7 @@ import copy
from typing import Any, Dict
from agentd.tool_decorator import tool
from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model, WATSONX_EMBEDDING_DIMENSIONS
from auth_context import get_auth_context
from auth_context import get_auth_context, get_current_user_groups
from utils.logging_config import get_logger
logger = get_logger(__name__)
@ -259,8 +259,24 @@ class SearchService:
# Build query body
if is_wildcard_match_all:
# Match all documents; still allow filters to narrow scope
if filter_clauses:
query_block = {"bool": {"filter": filter_clauses}}
# Also add RBAC group filter for wildcard queries
wildcard_filters = list(filter_clauses) # Copy existing filters
user_groups = get_current_user_groups()
if user_groups:
groups_access_filter = {
"bool": {
"should": [
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
{"terms": {"allowed_groups": user_groups}},
],
"minimum_should_match": 1
}
}
wildcard_filters.append(groups_access_filter)
if wildcard_filters:
query_block = {"bool": {"filter": wildcard_filters}}
else:
query_block = {"match_all": {}}
else:
@ -292,6 +308,29 @@ class SearchService:
# Add exists filter to existing filters
all_filters = [*filter_clauses, exists_any_embedding]
# RBAC: Add group-based access control filter (fallback if DLS isn't configured)
# Documents are accessible if:
# 1. No allowed_groups field exists (backward compatibility, open access)
# 2. User's groups match any of the document's allowed_groups
user_groups = get_current_user_groups()
if user_groups:
groups_access_filter = {
"bool": {
"should": [
# Document has no allowed_groups restriction
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
# User's groups match document's allowed_groups
{"terms": {"allowed_groups": user_groups}},
],
"minimum_should_match": 1
}
}
all_filters.append(groups_access_filter)
logger.debug(
"Added RBAC group filter",
user_groups=user_groups,
)
logger.debug(
"Building hybrid query with filters",
user_filters_count=len(filter_clauses),

View file

@ -2,8 +2,8 @@ import json
import jwt
import httpx
from datetime import datetime, timedelta
from typing import Dict, Optional, Any
from dataclasses import dataclass, asdict
from typing import Dict, Optional, Any, List
from dataclasses import dataclass, field
from cryptography.hazmat.primitives import serialization
import os
from utils.logging_config import get_logger
@ -24,12 +24,19 @@ class User:
provider: str = "google"
created_at: datetime = None
last_login: datetime = None
# RBAC fields
roles: List[str] = field(default_factory=lambda: ["openrag_user"])
groups: List[str] = field(default_factory=list)
def __post_init__(self):
    """Replace any of the defaulted fields that is still ``None``.

    ``created_at`` and ``last_login`` fall back to the current time, while
    ``roles`` and ``groups`` fall back to ``["openrag_user"]`` and an empty
    list respectively. Fields that already hold a value are left untouched.
    """
    # (attribute name, zero-argument factory producing its fallback value);
    # processed in order, each factory invoked only when actually needed.
    fallbacks = (
        ("created_at", datetime.now),
        ("last_login", datetime.now),
        ("roles", lambda: ["openrag_user"]),
        ("groups", list),
    )
    for attr_name, make_default in fallbacks:
        if getattr(self, attr_name) is None:
            setattr(self, attr_name, make_default())
class AnonymousUser(User):
"""Anonymous user"""
@ -136,11 +143,31 @@ class SessionManager:
# Create JWT token using the shared method
return self.create_jwt_token(user)
def create_jwt_token(self, user: User) -> str:
"""Create JWT token for an existing user"""
def create_jwt_token(
self,
user: User,
roles: Optional[List[str]] = None,
groups: Optional[List[str]] = None,
expiration_days: int = 7,
) -> str:
"""Create JWT token for an existing user.
Args:
user: The User object to create a token for
roles: Optional roles override (for restricted API key tokens)
groups: Optional groups override (for restricted API key tokens)
expiration_days: Token expiration in days (default 7)
Returns:
Encoded JWT token string
"""
# Use OpenSearch-compatible issuer for OIDC validation
oidc_issuer = "http://openrag-backend:8000"
# Use provided roles/groups or fall back to user's defaults
effective_roles = roles if roles is not None else user.roles
effective_groups = groups if groups is not None else user.groups
# Create JWT token with OIDC-compliant claims
now = datetime.utcnow()
token_payload = {
@ -148,7 +175,7 @@ class SessionManager:
"iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC
"sub": user.user_id, # Subject (user ID)
"aud": ["opensearch", "openrag"], # Audience
"exp": now + timedelta(days=7), # Expiration
"exp": now + timedelta(days=expiration_days), # Expiration
"iat": now, # Issued at
"auth_time": int(now.timestamp()), # Authentication time
# Custom claims
@ -157,7 +184,9 @@ class SessionManager:
"name": user.name,
"preferred_username": user.email,
"email_verified": True,
"roles": ["openrag_user"], # Backend role for OpenSearch
# RBAC claims
"roles": effective_roles, # Backend roles for OpenSearch/tools
"groups": effective_groups, # Group-based access control
}
token = jwt.encode(token_payload, self.private_key, algorithm="RS256")