From ce51628db2c41e7c6aa766bf3eba756909d2770c Mon Sep 17 00:00:00 2001 From: Edwin Jose Date: Fri, 26 Dec 2025 16:41:39 -0500 Subject: [PATCH] RBAC basic implementation --- .../api/mutations/useCreateApiKeyMutation.ts | 4 + .../app/api/queries/useGetApiKeysQuery.ts | 2 + frontend/app/settings/page.tsx | 155 ++++++++++++------ src/api/chat.py | 16 ++ src/api/keys.py | 28 +++- src/api/v1/chat.py | 12 +- src/api_key_middleware.py | 72 +++++++- src/auth_context.py | 36 +++- src/main.py | 30 ++-- src/models/processors.py | 25 +++ src/services/api_key_service.py | 31 +++- src/services/chat_service.py | 10 +- src/services/search_service.py | 45 ++++- src/session_manager.py | 41 ++++- 14 files changed, 409 insertions(+), 98 deletions(-) diff --git a/frontend/app/api/mutations/useCreateApiKeyMutation.ts b/frontend/app/api/mutations/useCreateApiKeyMutation.ts index ad79693d..2d7cd1e1 100644 --- a/frontend/app/api/mutations/useCreateApiKeyMutation.ts +++ b/frontend/app/api/mutations/useCreateApiKeyMutation.ts @@ -6,6 +6,8 @@ import { export interface CreateApiKeyRequest { name: string; + roles?: string[]; + groups?: string[]; } export interface CreateApiKeyResponse { @@ -14,6 +16,8 @@ export interface CreateApiKeyResponse { name: string; key_prefix: string; created_at: string; + roles?: string[]; + groups?: string[]; } export const useCreateApiKeyMutation = ( diff --git a/frontend/app/api/queries/useGetApiKeysQuery.ts b/frontend/app/api/queries/useGetApiKeysQuery.ts index 60803ec5..f8e55c5d 100644 --- a/frontend/app/api/queries/useGetApiKeysQuery.ts +++ b/frontend/app/api/queries/useGetApiKeysQuery.ts @@ -6,6 +6,8 @@ export interface ApiKey { key_prefix: string; created_at: string; last_used_at: string | null; + roles?: string[]; + groups?: string[]; } export interface GetApiKeysResponse { diff --git a/frontend/app/settings/page.tsx b/frontend/app/settings/page.tsx index d5e56ca9..fc8b0322 100644 --- a/frontend/app/settings/page.tsx +++ b/frontend/app/settings/page.tsx @@ -136,6 +136,7 @@ function KnowledgeSourcesPage() { // API Keys state const [createKeyDialogOpen, setCreateKeyDialogOpen] = useState(false); const [newKeyName, setNewKeyName] = useState(""); + const [newKeyGroups, setNewKeyGroups] = useState(""); const [newlyCreatedKey, setNewlyCreatedKey] = useState(null); const [showKeyDialogOpen, setShowKeyDialogOpen] = useState(false); @@ -156,6 +157,7 @@ function KnowledgeSourcesPage() { setCreateKeyDialogOpen(false); setShowKeyDialogOpen(true); setNewKeyName(""); + setNewKeyGroups(""); toast.success("API key created"); }, onError: (error) => { @@ -438,7 +440,16 @@ function KnowledgeSourcesPage() { toast.error("Please enter a name for the API key"); return; } - createApiKeyMutation.mutate({ name: newKeyName.trim() }); + // Parse groups from comma-separated string + const groups = newKeyGroups + .split(",") + .map((g) => g.trim()) + .filter((g) => g.length > 0); + + createApiKeyMutation.mutate({ + name: newKeyName.trim(), + groups: groups.length > 0 ? groups : undefined, + }); }; const handleRevokeApiKey = (keyId: string) => { @@ -1426,6 +1437,9 @@ function KnowledgeSourcesPage() { Key + + Groups + Created @@ -1448,6 +1462,24 @@ function KnowledgeSourcesPage() { {key.key_prefix}... + + {key.groups && key.groups.length > 0 ? ( +
+ {key.groups.map((group: string) => ( + + {group} + + ))} +
+ ) : ( + + All groups + + )} + {formatDate(key.created_at)} @@ -1507,58 +1539,77 @@ function KnowledgeSourcesPage() { )} - {/* Create API Key Dialog */} - - - - Create API Key - - Give your API key a name to help you identify it later. - - -
- - setNewKeyName(e.target.value)} - onKeyDown={(e) => { - if (e.key === "Enter") { - handleCreateApiKey(); - } - }} - /> - -
- - - - -
-
+ /> + + + setNewKeyGroups(e.target.value)} + onKeyDown={(e) => { + if (e.key === "Enter") { + handleCreateApiKey(); + } + }} + /> + + + + + + + + {/* Show Created API Key Dialog */} str | None: return None -def require_api_key(api_key_service): +def require_api_key(api_key_service, session_manager=None): """ Decorator to require API key authentication for public API endpoints. + + Generates an ephemeral JWT with the API key's specific roles and groups + to enforce RBAC in downstream services (OpenSearch, tools, etc.). Usage: - @require_api_key(api_key_service) + @require_api_key(api_key_service, session_manager) async def my_endpoint(request): user = request.state.user + jwt_token = request.state.jwt_token # Ephemeral restricted JWT ... """ @@ -57,7 +64,7 @@ def require_api_key(api_key_service): status_code=401, ) - # Validate the key + # Validate the key and get RBAC claims user_info = await api_key_service.validate_key(api_key) if not user_info: @@ -69,19 +76,46 @@ def require_api_key(api_key_service): status_code=401, ) - # Create a User object from the API key info + # Extract RBAC fields from API key + key_roles = user_info.get("roles", ["openrag_user"]) + key_groups = user_info.get("groups", []) + + # Create a User object with the API key's roles and groups user = User( user_id=user_info["user_id"], email=user_info["user_email"], name=user_info.get("name", "API User"), picture=None, provider="api_key", + roles=key_roles, + groups=key_groups, ) # Set request state request.state.user = user request.state.api_key_id = user_info["key_id"] - request.state.jwt_token = None # No JWT for API key auth + + # Generate ephemeral JWT with the API key's restricted roles/groups + if session_manager: + # Create a short-lived JWT with the key's specific permissions + ephemeral_jwt = session_manager.create_jwt_token( + user=user, + roles=key_roles, + groups=key_groups, + expiration_days=1, # Short-lived for API requests + ) + request.state.jwt_token = ephemeral_jwt + logger.debug( + "Generated ephemeral JWT for API key", + key_id=user_info["key_id"], + roles=key_roles, + groups=key_groups, + ) + else: + request.state.jwt_token = None + logger.warning( + "No session_manager provided - JWT not generated for API key" + ) return await handler(request) @@ -90,10 +124,13 @@ def require_api_key(api_key_service): return decorator -def optional_api_key(api_key_service): +def optional_api_key(api_key_service, session_manager=None): """ Decorator to optionally authenticate with API key. Sets request.state.user to None if no valid API key is provided. + + When a valid API key is provided, generates an ephemeral JWT with + the key's specific roles and groups. """ def decorator(handler): @@ -102,21 +139,38 @@ def optional_api_key(api_key_service): api_key = _extract_api_key(request) if api_key: - # Validate the key + # Validate the key and get RBAC claims user_info = await api_key_service.validate_key(api_key) if user_info: - # Create a User object from the API key info + # Extract RBAC fields from API key + key_roles = user_info.get("roles", ["openrag_user"]) + key_groups = user_info.get("groups", []) + + # Create a User object with the API key's roles and groups user = User( user_id=user_info["user_id"], email=user_info["user_email"], name=user_info.get("name", "API User"), picture=None, provider="api_key", + roles=key_roles, + groups=key_groups, ) request.state.user = user request.state.api_key_id = user_info["key_id"] - request.state.jwt_token = None + + # Generate ephemeral JWT with the API key's restricted roles/groups + if session_manager: + ephemeral_jwt = session_manager.create_jwt_token( + user=user, + roles=key_roles, + groups=key_groups, + expiration_days=1, + ) + request.state.jwt_token = ephemeral_jwt + else: + request.state.jwt_token = None else: request.state.user = None request.state.api_key_id = None diff --git a/src/auth_context.py b/src/auth_context.py index 0693087b..780f8dad 100644 --- a/src/auth_context.py +++ b/src/auth_context.py @@ -4,7 +4,7 @@ Uses contextvars to safely pass user auth info through async calls. """ from contextvars import ContextVar -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List # Context variables for current request authentication _current_user_id: ContextVar[Optional[str]] = ContextVar( @@ -13,6 +13,12 @@ _current_user_id: ContextVar[Optional[str]] = ContextVar( _current_jwt_token: ContextVar[Optional[str]] = ContextVar( "current_jwt_token", default=None ) +_current_user_groups: ContextVar[Optional[List[str]]] = ContextVar( + "current_user_groups", default=None +) +_current_user_roles: ContextVar[Optional[List[str]]] = ContextVar( + "current_user_roles", default=None +) _current_search_filters: ContextVar[Optional[Dict[str, Any]]] = ContextVar( "current_search_filters", default=None ) @@ -24,10 +30,24 @@ _current_score_threshold: ContextVar[Optional[float]] = ContextVar( ) -def set_auth_context(user_id: str, jwt_token: str): - """Set authentication context for the current async context""" +def set_auth_context( + user_id: str, + jwt_token: str, + groups: Optional[List[str]] = None, + roles: Optional[List[str]] = None, +): + """Set authentication context for the current async context + + Args: + user_id: The user's ID + jwt_token: The JWT token for authentication + groups: Optional list of groups the user belongs to (for RBAC) + roles: Optional list of roles the user has (for RBAC) + """ _current_user_id.set(user_id) _current_jwt_token.set(jwt_token) + _current_user_groups.set(groups or []) + _current_user_roles.set(roles or []) def get_current_user_id() -> Optional[str]: @@ -40,6 +60,16 @@ def get_current_jwt_token() -> Optional[str]: return _current_jwt_token.get() +def get_current_user_groups() -> List[str]: + """Get current user's groups from context (for RBAC)""" + return _current_user_groups.get() or [] + + +def get_current_user_roles() -> List[str]: + """Get current user's roles from context (for RBAC)""" + return _current_user_roles.get() or [] + + def get_auth_context() -> tuple[Optional[str], Optional[str]]: """Get current authentication context (user_id, jwt_token)""" return _current_user_id.get(), _current_jwt_token.get() diff --git a/src/main.py b/src/main.py index 710b3dab..7830dce8 100644 --- a/src/main.py +++ b/src/main.py @@ -1314,7 +1314,7 @@ async def create_app(): # Chat endpoints Route( "/v1/chat", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_chat.chat_create_endpoint, chat_service=services["chat_service"], @@ -1325,7 +1325,7 @@ async def create_app(): ), Route( "/v1/chat", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_chat.chat_list_endpoint, chat_service=services["chat_service"], @@ -1336,7 +1336,7 @@ async def create_app(): ), Route( "/v1/chat/{chat_id}", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_chat.chat_get_endpoint, chat_service=services["chat_service"], @@ -1347,7 +1347,7 @@ async def create_app(): ), Route( "/v1/chat/{chat_id}", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_chat.chat_delete_endpoint, chat_service=services["chat_service"], @@ -1359,7 +1359,7 @@ async def create_app(): # Search endpoint Route( "/v1/search", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_search.search_endpoint, search_service=services["search_service"], @@ -1371,7 +1371,7 @@ async def create_app(): # Documents endpoints Route( "/v1/documents/ingest", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_documents.ingest_endpoint, document_service=services["document_service"], @@ -1384,7 +1384,7 @@ async def create_app(): ), Route( "/v1/tasks/{task_id}", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_documents.task_status_endpoint, task_service=services["task_service"], @@ -1395,7 +1395,7 @@ async def create_app(): ), Route( "/v1/documents", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_documents.delete_document_endpoint, document_service=services["document_service"], @@ -1407,14 +1407,14 @@ async def create_app(): # Settings endpoints Route( "/v1/settings", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial(v1_settings.get_settings_endpoint) ), methods=["GET"], ), Route( "/v1/settings", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_settings.update_settings_endpoint, session_manager=services["session_manager"], @@ -1425,7 +1425,7 @@ async def create_app(): # Knowledge filters endpoints Route( "/v1/knowledge-filters", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_knowledge_filters.create_endpoint, knowledge_filter_service=services["knowledge_filter_service"], @@ -1436,7 +1436,7 @@ async def create_app(): ), Route( "/v1/knowledge-filters/search", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_knowledge_filters.search_endpoint, knowledge_filter_service=services["knowledge_filter_service"], @@ -1447,7 +1447,7 @@ async def create_app(): ), Route( "/v1/knowledge-filters/{filter_id}", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_knowledge_filters.get_endpoint, knowledge_filter_service=services["knowledge_filter_service"], @@ -1458,7 +1458,7 @@ async def create_app(): ), Route( "/v1/knowledge-filters/{filter_id}", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_knowledge_filters.update_endpoint, knowledge_filter_service=services["knowledge_filter_service"], @@ -1469,7 +1469,7 @@ async def create_app(): ), Route( "/v1/knowledge-filters/{filter_id}", - require_api_key(services["api_key_service"])( + require_api_key(services["api_key_service"], services["session_manager"])( partial( v1_knowledge_filters.delete_endpoint, knowledge_filter_service=services["knowledge_filter_service"], diff --git a/src/models/processors.py b/src/models/processors.py index d8de30c5..766896b4 100644 --- a/src/models/processors.py +++ b/src/models/processors.py @@ -158,6 +158,7 @@ class TaskProcessor: connector_type: str = "local", embedding_model: str = None, is_sample_data: bool = False, + allowed_groups: list = None, ): """ Standard processing pipeline for non-Langflow processors: @@ -166,6 +167,8 @@ class TaskProcessor: Args: embedding_model: Embedding model to use (defaults to the current embedding model from settings) + allowed_groups: List of groups that can access this document (RBAC). + Empty list or None means no group restrictions. """ import datetime from config.settings import INDEX_NAME, clients, get_embedding_model @@ -259,6 +262,11 @@ class TaskProcessor: if owner_email is not None: chunk_doc["owner_email"] = owner_email + # RBAC: Set allowed groups for access control + # If allowed_groups is provided and non-empty, store it for DLS filtering + if allowed_groups: + chunk_doc["allowed_groups"] = allowed_groups + # Mark as sample data if specified if is_sample_data: chunk_doc["is_sample_data"] = "true" @@ -309,6 +317,7 @@ class DocumentFileProcessor(TaskProcessor): owner_name: str = None, owner_email: str = None, is_sample_data: bool = False, + allowed_groups: list = None, ): super().__init__(document_service) self.owner_user_id = owner_user_id @@ -316,6 +325,7 @@ class DocumentFileProcessor(TaskProcessor): self.owner_name = owner_name self.owner_email = owner_email self.is_sample_data = is_sample_data + self.allowed_groups = allowed_groups async def process_item( self, upload_task: UploadTask, item: str, file_task: FileTask @@ -351,6 +361,7 @@ class DocumentFileProcessor(TaskProcessor): file_size=file_size, connector_type="local", is_sample_data=self.is_sample_data, + allowed_groups=self.allowed_groups, ) file_task.status = TaskStatus.COMPLETED @@ -382,6 +393,7 @@ class ConnectorFileProcessor(TaskProcessor): owner_name: str = None, owner_email: str = None, document_service=None, + allowed_groups: list = None, ): super().__init__(document_service=document_service) self.connector_service = connector_service @@ -391,6 +403,7 @@ class ConnectorFileProcessor(TaskProcessor): self.jwt_token = jwt_token self.owner_name = owner_name self.owner_email = owner_email + self.allowed_groups = allowed_groups async def process_item( self, upload_task: UploadTask, item: str, file_task: FileTask @@ -445,6 +458,7 @@ class ConnectorFileProcessor(TaskProcessor): owner_email=self.owner_email, file_size=len(document.content), connector_type=connection.connector_type, + allowed_groups=self.allowed_groups, ) # Add connector-specific metadata @@ -478,6 +492,7 @@ class LangflowConnectorFileProcessor(TaskProcessor): jwt_token: str = None, owner_name: str = None, owner_email: str = None, + allowed_groups: list = None, ): super().__init__() self.langflow_connector_service = langflow_connector_service @@ -487,6 +502,7 @@ class LangflowConnectorFileProcessor(TaskProcessor): self.jwt_token = jwt_token self.owner_name = owner_name self.owner_email = owner_email + self.allowed_groups = allowed_groups async def process_item( self, upload_task: UploadTask, item: str, file_task: FileTask @@ -580,6 +596,7 @@ class S3FileProcessor(TaskProcessor): jwt_token: str = None, owner_name: str = None, owner_email: str = None, + allowed_groups: list = None, ): import boto3 @@ -590,6 +607,7 @@ class S3FileProcessor(TaskProcessor): self.jwt_token = jwt_token self.owner_name = owner_name self.owner_email = owner_email + self.allowed_groups = allowed_groups async def process_item( self, upload_task: UploadTask, item: str, file_task: FileTask @@ -638,6 +656,7 @@ class S3FileProcessor(TaskProcessor): owner_email=self.owner_email, file_size=file_size, connector_type="s3", + allowed_groups=self.allowed_groups, ) result["path"] = f"s3://{self.bucket}/{item}" @@ -669,6 +688,7 @@ class LangflowFileProcessor(TaskProcessor): settings: dict = None, delete_after_ingest: bool = True, replace_duplicates: bool = False, + allowed_groups: list = None, ): super().__init__() self.langflow_file_service = langflow_file_service @@ -682,6 +702,7 @@ class LangflowFileProcessor(TaskProcessor): self.settings = settings self.delete_after_ingest = delete_after_ingest self.replace_duplicates = replace_duplicates + self.allowed_groups = allowed_groups async def process_item( self, upload_task: UploadTask, item: str, file_task: FileTask @@ -765,6 +786,10 @@ class LangflowFileProcessor(TaskProcessor): metadata_tweaks.append({"key": "owner_email", "value": self.owner_email}) # Mark as local upload for connector_type metadata_tweaks.append({"key": "connector_type", "value": "local"}) + # RBAC: Add allowed_groups for access control + if self.allowed_groups: + # Store as comma-separated string for Langflow metadata + metadata_tweaks.append({"key": "allowed_groups", "value": ",".join(self.allowed_groups)}) if metadata_tweaks: # Initialize the OpenSearch component tweaks if not already present diff --git a/src/services/api_key_service.py b/src/services/api_key_service.py index c519ecd7..c606e8ac 100644 --- a/src/services/api_key_service.py +++ b/src/services/api_key_service.py @@ -52,15 +52,19 @@ class APIKeyService: user_email: str, name: str, jwt_token: str = None, + roles: List[str] = None, + groups: List[str] = None, ) -> Dict[str, Any]: """ - Create a new API key for a user. + Create a new API key for a user with optional RBAC restrictions. Args: user_id: The user's ID user_email: The user's email name: A friendly name for the key jwt_token: JWT token for OpenSearch authentication + roles: Optional list of roles to restrict this key to + groups: Optional list of groups this key can access Returns: Dict with success status, key info, and the full key (only shown once) @@ -74,6 +78,14 @@ class APIKeyService: now = datetime.utcnow().isoformat() + # Default roles if not specified + if roles is None: + roles = ["openrag_user"] + + # Default groups to empty if not specified + if groups is None: + groups = [] + # Create the document to store key_doc = { "key_id": key_id, @@ -85,6 +97,9 @@ class APIKeyService: "created_at": now, "last_used_at": None, "revoked": False, + # RBAC fields + "roles": roles, + "groups": groups, } # Get OpenSearch client @@ -105,6 +120,8 @@ class APIKeyService: user_id=user_id, key_id=key_id, key_prefix=key_prefix, + roles=roles, + groups=groups, ) return { "success": True, @@ -112,6 +129,8 @@ class APIKeyService: "key_prefix": key_prefix, "name": name, "created_at": now, + "roles": roles, + "groups": groups, "api_key": full_key, # Only returned once! } else: @@ -123,13 +142,13 @@ class APIKeyService: async def validate_key(self, api_key: str) -> Optional[Dict[str, Any]]: """ - Validate an API key and return user info if valid. + Validate an API key and return user info with RBAC claims if valid. Args: api_key: The full API key to validate Returns: - Dict with user info if valid, None if invalid + Dict with user info including roles and groups if valid, None if invalid """ try: # Check key format @@ -181,11 +200,15 @@ class APIKeyService: except Exception: pass # Don't fail validation if update fails + # Return user info with RBAC claims return { "key_id": key_doc["key_id"], "user_id": key_doc["user_id"], "user_email": key_doc["user_email"], "name": key_doc["name"], + # RBAC fields - provide defaults for backward compatibility + "roles": key_doc.get("roles", ["openrag_user"]), + "groups": key_doc.get("groups", []), } except Exception as e: @@ -225,6 +248,8 @@ class APIKeyService: "created_at", "last_used_at", "revoked", + "roles", + "groups", ], "size": 100, } diff --git a/src/services/chat_service.py b/src/services/chat_service.py index 92c834a8..bcecfc4f 100644 --- a/src/services/chat_service.py +++ b/src/services/chat_service.py @@ -16,6 +16,8 @@ class ChatService: previous_response_id: str = None, stream: bool = False, filter_id: str = None, + groups: list = None, + roles: list = None, ): """Handle chat requests using the patched OpenAI client""" if not prompt: @@ -23,7 +25,7 @@ class ChatService: # Set authentication context for this request so tools can access it if user_id and jwt_token: - set_auth_context(user_id, jwt_token) + set_auth_context(user_id, jwt_token, groups=groups, roles=roles) if stream: return async_chat_stream( @@ -54,6 +56,8 @@ class ChatService: previous_response_id: str = None, stream: bool = False, filter_id: str = None, + groups: list = None, + roles: list = None, ): """Handle Langflow chat requests""" if not prompt: @@ -346,9 +350,9 @@ class ChatService: previous_response_id=previous_response_id, ) else: # chat - # Set auth context for chat tools and provide user_id + # Set auth context for chat tools and provide user_id with RBAC if user_id and jwt_token: - set_auth_context(user_id, jwt_token) + set_auth_context(user_id, jwt_token, groups=groups, roles=roles) response_text, response_id = await async_chat( clients.patched_llm_client, document_prompt, diff --git a/src/services/search_service.py b/src/services/search_service.py index b0927d0f..b1492ef3 100644 --- a/src/services/search_service.py +++ b/src/services/search_service.py @@ -2,7 +2,7 @@ import copy from typing import Any, Dict from agentd.tool_decorator import tool from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model, WATSONX_EMBEDDING_DIMENSIONS -from auth_context import get_auth_context +from auth_context import get_auth_context, get_current_user_groups from utils.logging_config import get_logger logger = get_logger(__name__) @@ -259,8 +259,24 @@ class SearchService: # Build query body if is_wildcard_match_all: # Match all documents; still allow filters to narrow scope - if filter_clauses: - query_block = {"bool": {"filter": filter_clauses}} + # Also add RBAC group filter for wildcard queries + wildcard_filters = list(filter_clauses) # Copy existing filters + + user_groups = get_current_user_groups() + if user_groups: + groups_access_filter = { + "bool": { + "should": [ + {"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}}, + {"terms": {"allowed_groups": user_groups}}, + ], + "minimum_should_match": 1 + } + } + wildcard_filters.append(groups_access_filter) + + if wildcard_filters: + query_block = {"bool": {"filter": wildcard_filters}} else: query_block = {"match_all": {}} else: @@ -292,6 +308,29 @@ class SearchService: # Add exists filter to existing filters all_filters = [*filter_clauses, exists_any_embedding] + # RBAC: Add group-based access control filter (fallback if DLS isn't configured) + # Documents are accessible if: + # 1. No allowed_groups field exists (backward compatibility, open access) + # 2. User's groups match any of the document's allowed_groups + user_groups = get_current_user_groups() + if user_groups: + groups_access_filter = { + "bool": { + "should": [ + # Document has no allowed_groups restriction + {"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}}, + # User's groups match document's allowed_groups + {"terms": {"allowed_groups": user_groups}}, + ], + "minimum_should_match": 1 + } + } + all_filters.append(groups_access_filter) + logger.debug( + "Added RBAC group filter", + user_groups=user_groups, + ) + logger.debug( "Building hybrid query with filters", user_filters_count=len(filter_clauses), diff --git a/src/session_manager.py b/src/session_manager.py index 6b2023d5..14eb23d6 100644 --- a/src/session_manager.py +++ b/src/session_manager.py @@ -2,8 +2,8 @@ import json import jwt import httpx from datetime import datetime, timedelta -from typing import Dict, Optional, Any -from dataclasses import dataclass, asdict +from typing import Dict, Optional, Any, List +from dataclasses import dataclass, field from cryptography.hazmat.primitives import serialization import os from utils.logging_config import get_logger @@ -24,12 +24,19 @@ class User: provider: str = "google" created_at: datetime = None last_login: datetime = None + # RBAC fields + roles: List[str] = field(default_factory=lambda: ["openrag_user"]) + groups: List[str] = field(default_factory=list) def __post_init__(self): if self.created_at is None: self.created_at = datetime.now() if self.last_login is None: self.last_login = datetime.now() + if self.roles is None: + self.roles = ["openrag_user"] + if self.groups is None: + self.groups = [] class AnonymousUser(User): """Anonymous user""" @@ -136,11 +143,31 @@ class SessionManager: # Create JWT token using the shared method return self.create_jwt_token(user) - def create_jwt_token(self, user: User) -> str: - """Create JWT token for an existing user""" + def create_jwt_token( + self, + user: User, + roles: Optional[List[str]] = None, + groups: Optional[List[str]] = None, + expiration_days: int = 7, + ) -> str: + """Create JWT token for an existing user. + + Args: + user: The User object to create a token for + roles: Optional roles override (for restricted API key tokens) + groups: Optional groups override (for restricted API key tokens) + expiration_days: Token expiration in days (default 7) + + Returns: + Encoded JWT token string + """ # Use OpenSearch-compatible issuer for OIDC validation oidc_issuer = "http://openrag-backend:8000" + # Use provided roles/groups or fall back to user's defaults + effective_roles = roles if roles is not None else user.roles + effective_groups = groups if groups is not None else user.groups + # Create JWT token with OIDC-compliant claims now = datetime.utcnow() token_payload = { @@ -148,7 +175,7 @@ class SessionManager: "iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC "sub": user.user_id, # Subject (user ID) "aud": ["opensearch", "openrag"], # Audience - "exp": now + timedelta(days=7), # Expiration + "exp": now + timedelta(days=expiration_days), # Expiration "iat": now, # Issued at "auth_time": int(now.timestamp()), # Authentication time # Custom claims @@ -157,7 +184,9 @@ class SessionManager: "name": user.name, "preferred_username": user.email, "email_verified": True, - "roles": ["openrag_user"], # Backend role for OpenSearch + # RBAC claims + "roles": effective_roles, # Backend roles for OpenSearch/tools + "groups": effective_groups, # Group-based access control } token = jwt.encode(token_payload, self.private_key, algorithm="RS256")