RBAC basic implementation
This commit is contained in:
parent
c923ecb396
commit
ce51628db2
14 changed files with 409 additions and 98 deletions
|
|
@ -6,6 +6,8 @@ import {
|
||||||
|
|
||||||
export interface CreateApiKeyRequest {
|
export interface CreateApiKeyRequest {
|
||||||
name: string;
|
name: string;
|
||||||
|
roles?: string[];
|
||||||
|
groups?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface CreateApiKeyResponse {
|
export interface CreateApiKeyResponse {
|
||||||
|
|
@ -14,6 +16,8 @@ export interface CreateApiKeyResponse {
|
||||||
name: string;
|
name: string;
|
||||||
key_prefix: string;
|
key_prefix: string;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
|
roles?: string[];
|
||||||
|
groups?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useCreateApiKeyMutation = (
|
export const useCreateApiKeyMutation = (
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,8 @@ export interface ApiKey {
|
||||||
key_prefix: string;
|
key_prefix: string;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
last_used_at: string | null;
|
last_used_at: string | null;
|
||||||
|
roles?: string[];
|
||||||
|
groups?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface GetApiKeysResponse {
|
export interface GetApiKeysResponse {
|
||||||
|
|
|
||||||
|
|
@ -136,6 +136,7 @@ function KnowledgeSourcesPage() {
|
||||||
// API Keys state
|
// API Keys state
|
||||||
const [createKeyDialogOpen, setCreateKeyDialogOpen] = useState(false);
|
const [createKeyDialogOpen, setCreateKeyDialogOpen] = useState(false);
|
||||||
const [newKeyName, setNewKeyName] = useState("");
|
const [newKeyName, setNewKeyName] = useState("");
|
||||||
|
const [newKeyGroups, setNewKeyGroups] = useState("");
|
||||||
const [newlyCreatedKey, setNewlyCreatedKey] = useState<string | null>(null);
|
const [newlyCreatedKey, setNewlyCreatedKey] = useState<string | null>(null);
|
||||||
const [showKeyDialogOpen, setShowKeyDialogOpen] = useState(false);
|
const [showKeyDialogOpen, setShowKeyDialogOpen] = useState(false);
|
||||||
|
|
||||||
|
|
@ -156,6 +157,7 @@ function KnowledgeSourcesPage() {
|
||||||
setCreateKeyDialogOpen(false);
|
setCreateKeyDialogOpen(false);
|
||||||
setShowKeyDialogOpen(true);
|
setShowKeyDialogOpen(true);
|
||||||
setNewKeyName("");
|
setNewKeyName("");
|
||||||
|
setNewKeyGroups("");
|
||||||
toast.success("API key created");
|
toast.success("API key created");
|
||||||
},
|
},
|
||||||
onError: (error) => {
|
onError: (error) => {
|
||||||
|
|
@ -438,7 +440,16 @@ function KnowledgeSourcesPage() {
|
||||||
toast.error("Please enter a name for the API key");
|
toast.error("Please enter a name for the API key");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
createApiKeyMutation.mutate({ name: newKeyName.trim() });
|
// Parse groups from comma-separated string
|
||||||
|
const groups = newKeyGroups
|
||||||
|
.split(",")
|
||||||
|
.map((g) => g.trim())
|
||||||
|
.filter((g) => g.length > 0);
|
||||||
|
|
||||||
|
createApiKeyMutation.mutate({
|
||||||
|
name: newKeyName.trim(),
|
||||||
|
groups: groups.length > 0 ? groups : undefined,
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleRevokeApiKey = (keyId: string) => {
|
const handleRevokeApiKey = (keyId: string) => {
|
||||||
|
|
@ -1426,6 +1437,9 @@ function KnowledgeSourcesPage() {
|
||||||
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
||||||
Key
|
Key
|
||||||
</th>
|
</th>
|
||||||
|
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
||||||
|
Groups
|
||||||
|
</th>
|
||||||
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
||||||
Created
|
Created
|
||||||
</th>
|
</th>
|
||||||
|
|
@ -1448,6 +1462,24 @@ function KnowledgeSourcesPage() {
|
||||||
{key.key_prefix}...
|
{key.key_prefix}...
|
||||||
</code>
|
</code>
|
||||||
</td>
|
</td>
|
||||||
|
<td className="px-4 py-3 text-sm text-muted-foreground">
|
||||||
|
{key.groups && key.groups.length > 0 ? (
|
||||||
|
<div className="flex flex-wrap gap-1">
|
||||||
|
{key.groups.map((group: string) => (
|
||||||
|
<span
|
||||||
|
key={group}
|
||||||
|
className="inline-flex items-center px-2 py-0.5 rounded-full text-xs bg-primary/10 text-primary"
|
||||||
|
>
|
||||||
|
{group}
|
||||||
|
</span>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<span className="text-muted-foreground/50">
|
||||||
|
All groups
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</td>
|
||||||
<td className="px-4 py-3 text-sm text-muted-foreground">
|
<td className="px-4 py-3 text-sm text-muted-foreground">
|
||||||
{formatDate(key.created_at)}
|
{formatDate(key.created_at)}
|
||||||
</td>
|
</td>
|
||||||
|
|
@ -1507,58 +1539,77 @@ function KnowledgeSourcesPage() {
|
||||||
</Card>
|
</Card>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Create API Key Dialog */}
|
{/* Create API Key Dialog */}
|
||||||
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
|
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
|
||||||
<DialogContent>
|
<DialogContent>
|
||||||
<DialogHeader>
|
<DialogHeader>
|
||||||
<DialogTitle>Create API Key</DialogTitle>
|
<DialogTitle>Create API Key</DialogTitle>
|
||||||
<DialogDescription>
|
<DialogDescription>
|
||||||
Give your API key a name to help you identify it later.
|
Create an API key with optional group restrictions for access
|
||||||
</DialogDescription>
|
control.
|
||||||
</DialogHeader>
|
</DialogDescription>
|
||||||
<div className="py-4">
|
</DialogHeader>
|
||||||
<LabelWrapper label="Name" id="api-key-name">
|
<div className="py-4 space-y-4">
|
||||||
<Input
|
<LabelWrapper label="Name" id="api-key-name">
|
||||||
id="api-key-name"
|
<Input
|
||||||
placeholder="e.g., Production App, Development"
|
id="api-key-name"
|
||||||
value={newKeyName}
|
placeholder="e.g., Production App, Development"
|
||||||
onChange={(e) => setNewKeyName(e.target.value)}
|
value={newKeyName}
|
||||||
onKeyDown={(e) => {
|
onChange={(e) => setNewKeyName(e.target.value)}
|
||||||
if (e.key === "Enter") {
|
onKeyDown={(e) => {
|
||||||
handleCreateApiKey();
|
if (e.key === "Enter") {
|
||||||
}
|
handleCreateApiKey();
|
||||||
}}
|
}
|
||||||
/>
|
|
||||||
</LabelWrapper>
|
|
||||||
</div>
|
|
||||||
<DialogFooter>
|
|
||||||
<Button
|
|
||||||
variant="ghost"
|
|
||||||
onClick={() => {
|
|
||||||
setCreateKeyDialogOpen(false);
|
|
||||||
setNewKeyName("");
|
|
||||||
}}
|
}}
|
||||||
size="sm"
|
/>
|
||||||
>
|
</LabelWrapper>
|
||||||
Cancel
|
<LabelWrapper
|
||||||
</Button>
|
label="Groups (optional)"
|
||||||
<Button
|
id="api-key-groups"
|
||||||
onClick={handleCreateApiKey}
|
helperText="Comma-separated list of groups this key can access"
|
||||||
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
|
>
|
||||||
size="sm"
|
<Input
|
||||||
>
|
id="api-key-groups"
|
||||||
{createApiKeyMutation.isPending ? (
|
placeholder="e.g., finance, hr, engineering"
|
||||||
<>
|
value={newKeyGroups}
|
||||||
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
|
onChange={(e) => setNewKeyGroups(e.target.value)}
|
||||||
Creating...
|
onKeyDown={(e) => {
|
||||||
</>
|
if (e.key === "Enter") {
|
||||||
) : (
|
handleCreateApiKey();
|
||||||
"Create Key"
|
}
|
||||||
)}
|
}}
|
||||||
</Button>
|
/>
|
||||||
</DialogFooter>
|
</LabelWrapper>
|
||||||
</DialogContent>
|
</div>
|
||||||
</Dialog>
|
<DialogFooter>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
onClick={() => {
|
||||||
|
setCreateKeyDialogOpen(false);
|
||||||
|
setNewKeyName("");
|
||||||
|
setNewKeyGroups("");
|
||||||
|
}}
|
||||||
|
size="sm"
|
||||||
|
>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
onClick={handleCreateApiKey}
|
||||||
|
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
|
||||||
|
size="sm"
|
||||||
|
>
|
||||||
|
{createApiKeyMutation.isPending ? (
|
||||||
|
<>
|
||||||
|
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
|
||||||
|
Creating...
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
"Create Key"
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
|
||||||
{/* Show Created API Key Dialog */}
|
{/* Show Created API Key Dialog */}
|
||||||
<Dialog
|
<Dialog
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,10 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
||||||
set_search_limit(limit)
|
set_search_limit(limit)
|
||||||
set_score_threshold(score_threshold)
|
set_score_threshold(score_threshold)
|
||||||
|
|
||||||
|
# Get RBAC groups and roles from user for access control
|
||||||
|
user_groups = getattr(user, "groups", [])
|
||||||
|
user_roles = getattr(user, "roles", [])
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
await chat_service.chat(
|
await chat_service.chat(
|
||||||
|
|
@ -44,6 +48,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=previous_response_id,
|
previous_response_id=previous_response_id,
|
||||||
stream=True,
|
stream=True,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
|
|
@ -61,6 +67,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=previous_response_id,
|
previous_response_id=previous_response_id,
|
||||||
stream=False,
|
stream=False,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
)
|
)
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
|
|
||||||
|
|
@ -81,6 +89,10 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
||||||
|
|
||||||
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
|
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
|
||||||
|
|
||||||
|
# Get RBAC groups and roles from user for access control
|
||||||
|
user_groups = getattr(user, "groups", [])
|
||||||
|
user_roles = getattr(user, "roles", [])
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
||||||
|
|
||||||
|
|
@ -105,6 +117,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=previous_response_id,
|
previous_response_id=previous_response_id,
|
||||||
stream=True,
|
stream=True,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
|
|
@ -122,6 +136,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=previous_response_id,
|
previous_response_id=previous_response_id,
|
||||||
stream=False,
|
stream=False,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
)
|
)
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,10 +42,14 @@ async def list_keys_endpoint(request: Request, api_key_service):
|
||||||
|
|
||||||
async def create_key_endpoint(request: Request, api_key_service):
|
async def create_key_endpoint(request: Request, api_key_service):
|
||||||
"""
|
"""
|
||||||
Create a new API key for the authenticated user.
|
Create a new API key for the authenticated user with optional RBAC restrictions.
|
||||||
|
|
||||||
POST /keys
|
POST /keys
|
||||||
Body: {"name": "My API Key"}
|
Body: {
|
||||||
|
"name": "My API Key",
|
||||||
|
"roles": ["openrag_user"], // Optional: restrict key to specific roles
|
||||||
|
"groups": ["finance", "hr"] // Optional: restrict key to specific groups
|
||||||
|
}
|
||||||
|
|
||||||
Response:
|
Response:
|
||||||
{
|
{
|
||||||
|
|
@ -53,6 +57,8 @@ async def create_key_endpoint(request: Request, api_key_service):
|
||||||
"key_id": "...",
|
"key_id": "...",
|
||||||
"key_prefix": "orag_abc12345",
|
"key_prefix": "orag_abc12345",
|
||||||
"name": "My API Key",
|
"name": "My API Key",
|
||||||
|
"roles": ["openrag_user"],
|
||||||
|
"groups": ["finance", "hr"],
|
||||||
"created_at": "2024-01-01T00:00:00",
|
"created_at": "2024-01-01T00:00:00",
|
||||||
"api_key": "orag_abc12345..." // Full key, only shown once!
|
"api_key": "orag_abc12345..." // Full key, only shown once!
|
||||||
}
|
}
|
||||||
|
|
@ -78,11 +84,29 @@ async def create_key_endpoint(request: Request, api_key_service):
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract optional RBAC fields
|
||||||
|
roles = data.get("roles")
|
||||||
|
groups = data.get("groups")
|
||||||
|
|
||||||
|
# Validate roles and groups are lists if provided
|
||||||
|
if roles is not None and not isinstance(roles, list):
|
||||||
|
return JSONResponse(
|
||||||
|
{"success": False, "error": "roles must be a list"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
if groups is not None and not isinstance(groups, list):
|
||||||
|
return JSONResponse(
|
||||||
|
{"success": False, "error": "groups must be a list"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
result = await api_key_service.create_key(
|
result = await api_key_service.create_key(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
user_email=user_email,
|
user_email=user_email,
|
||||||
name=name,
|
name=name,
|
||||||
jwt_token=jwt_token,
|
jwt_token=jwt_token,
|
||||||
|
roles=roles,
|
||||||
|
groups=groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.get("success"):
|
if result.get("success"):
|
||||||
|
|
|
||||||
|
|
@ -104,14 +104,18 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
||||||
|
|
||||||
user = request.state.user
|
user = request.state.user
|
||||||
user_id = user.user_id
|
user_id = user.user_id
|
||||||
jwt_token = session_manager.get_effective_jwt_token(user_id, None)
|
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
|
||||||
|
|
||||||
|
# Get RBAC groups and roles from user for access control
|
||||||
|
user_groups = getattr(user, "groups", [])
|
||||||
|
user_roles = getattr(user, "roles", [])
|
||||||
|
|
||||||
# Set context variables for search tool
|
# Set context variables for search tool
|
||||||
if filters:
|
if filters:
|
||||||
set_search_filters(filters)
|
set_search_filters(filters)
|
||||||
set_search_limit(limit)
|
set_search_limit(limit)
|
||||||
set_score_threshold(score_threshold)
|
set_score_threshold(score_threshold)
|
||||||
set_auth_context(user_id, jwt_token)
|
set_auth_context(user_id, jwt_token, groups=user_groups, roles=user_roles)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
raw_stream = await chat_service.langflow_chat(
|
raw_stream = await chat_service.langflow_chat(
|
||||||
|
|
@ -121,6 +125,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=chat_id,
|
previous_response_id=chat_id,
|
||||||
stream=True,
|
stream=True,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
)
|
)
|
||||||
chat_id_container = {}
|
chat_id_container = {}
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
|
|
@ -136,6 +142,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=chat_id,
|
previous_response_id=chat_id,
|
||||||
stream=False,
|
stream=False,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
)
|
)
|
||||||
# Transform response_id to chat_id for v1 API format
|
# Transform response_id to chat_id for v1 API format
|
||||||
return JSONResponse({
|
return JSONResponse({
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
"""
|
"""
|
||||||
API Key middleware for authenticating public API requests.
|
API Key middleware for authenticating public API requests.
|
||||||
|
|
||||||
|
This middleware validates API keys and generates ephemeral JWTs with the
|
||||||
|
key's specific roles and groups for downstream security enforcement.
|
||||||
"""
|
"""
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
@ -32,14 +35,18 @@ def _extract_api_key(request: Request) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def require_api_key(api_key_service):
|
def require_api_key(api_key_service, session_manager=None):
|
||||||
"""
|
"""
|
||||||
Decorator to require API key authentication for public API endpoints.
|
Decorator to require API key authentication for public API endpoints.
|
||||||
|
|
||||||
|
Generates an ephemeral JWT with the API key's specific roles and groups
|
||||||
|
to enforce RBAC in downstream services (OpenSearch, tools, etc.).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@require_api_key(api_key_service)
|
@require_api_key(api_key_service, session_manager)
|
||||||
async def my_endpoint(request):
|
async def my_endpoint(request):
|
||||||
user = request.state.user
|
user = request.state.user
|
||||||
|
jwt_token = request.state.jwt_token # Ephemeral restricted JWT
|
||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -57,7 +64,7 @@ def require_api_key(api_key_service):
|
||||||
status_code=401,
|
status_code=401,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate the key
|
# Validate the key and get RBAC claims
|
||||||
user_info = await api_key_service.validate_key(api_key)
|
user_info = await api_key_service.validate_key(api_key)
|
||||||
|
|
||||||
if not user_info:
|
if not user_info:
|
||||||
|
|
@ -69,19 +76,46 @@ def require_api_key(api_key_service):
|
||||||
status_code=401,
|
status_code=401,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a User object from the API key info
|
# Extract RBAC fields from API key
|
||||||
|
key_roles = user_info.get("roles", ["openrag_user"])
|
||||||
|
key_groups = user_info.get("groups", [])
|
||||||
|
|
||||||
|
# Create a User object with the API key's roles and groups
|
||||||
user = User(
|
user = User(
|
||||||
user_id=user_info["user_id"],
|
user_id=user_info["user_id"],
|
||||||
email=user_info["user_email"],
|
email=user_info["user_email"],
|
||||||
name=user_info.get("name", "API User"),
|
name=user_info.get("name", "API User"),
|
||||||
picture=None,
|
picture=None,
|
||||||
provider="api_key",
|
provider="api_key",
|
||||||
|
roles=key_roles,
|
||||||
|
groups=key_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set request state
|
# Set request state
|
||||||
request.state.user = user
|
request.state.user = user
|
||||||
request.state.api_key_id = user_info["key_id"]
|
request.state.api_key_id = user_info["key_id"]
|
||||||
request.state.jwt_token = None # No JWT for API key auth
|
|
||||||
|
# Generate ephemeral JWT with the API key's restricted roles/groups
|
||||||
|
if session_manager:
|
||||||
|
# Create a short-lived JWT with the key's specific permissions
|
||||||
|
ephemeral_jwt = session_manager.create_jwt_token(
|
||||||
|
user=user,
|
||||||
|
roles=key_roles,
|
||||||
|
groups=key_groups,
|
||||||
|
expiration_days=1, # Short-lived for API requests
|
||||||
|
)
|
||||||
|
request.state.jwt_token = ephemeral_jwt
|
||||||
|
logger.debug(
|
||||||
|
"Generated ephemeral JWT for API key",
|
||||||
|
key_id=user_info["key_id"],
|
||||||
|
roles=key_roles,
|
||||||
|
groups=key_groups,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
request.state.jwt_token = None
|
||||||
|
logger.warning(
|
||||||
|
"No session_manager provided - JWT not generated for API key"
|
||||||
|
)
|
||||||
|
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
|
||||||
|
|
@ -90,10 +124,13 @@ def require_api_key(api_key_service):
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def optional_api_key(api_key_service):
|
def optional_api_key(api_key_service, session_manager=None):
|
||||||
"""
|
"""
|
||||||
Decorator to optionally authenticate with API key.
|
Decorator to optionally authenticate with API key.
|
||||||
Sets request.state.user to None if no valid API key is provided.
|
Sets request.state.user to None if no valid API key is provided.
|
||||||
|
|
||||||
|
When a valid API key is provided, generates an ephemeral JWT with
|
||||||
|
the key's specific roles and groups.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(handler):
|
def decorator(handler):
|
||||||
|
|
@ -102,21 +139,38 @@ def optional_api_key(api_key_service):
|
||||||
api_key = _extract_api_key(request)
|
api_key = _extract_api_key(request)
|
||||||
|
|
||||||
if api_key:
|
if api_key:
|
||||||
# Validate the key
|
# Validate the key and get RBAC claims
|
||||||
user_info = await api_key_service.validate_key(api_key)
|
user_info = await api_key_service.validate_key(api_key)
|
||||||
|
|
||||||
if user_info:
|
if user_info:
|
||||||
# Create a User object from the API key info
|
# Extract RBAC fields from API key
|
||||||
|
key_roles = user_info.get("roles", ["openrag_user"])
|
||||||
|
key_groups = user_info.get("groups", [])
|
||||||
|
|
||||||
|
# Create a User object with the API key's roles and groups
|
||||||
user = User(
|
user = User(
|
||||||
user_id=user_info["user_id"],
|
user_id=user_info["user_id"],
|
||||||
email=user_info["user_email"],
|
email=user_info["user_email"],
|
||||||
name=user_info.get("name", "API User"),
|
name=user_info.get("name", "API User"),
|
||||||
picture=None,
|
picture=None,
|
||||||
provider="api_key",
|
provider="api_key",
|
||||||
|
roles=key_roles,
|
||||||
|
groups=key_groups,
|
||||||
)
|
)
|
||||||
request.state.user = user
|
request.state.user = user
|
||||||
request.state.api_key_id = user_info["key_id"]
|
request.state.api_key_id = user_info["key_id"]
|
||||||
request.state.jwt_token = None
|
|
||||||
|
# Generate ephemeral JWT with the API key's restricted roles/groups
|
||||||
|
if session_manager:
|
||||||
|
ephemeral_jwt = session_manager.create_jwt_token(
|
||||||
|
user=user,
|
||||||
|
roles=key_roles,
|
||||||
|
groups=key_groups,
|
||||||
|
expiration_days=1,
|
||||||
|
)
|
||||||
|
request.state.jwt_token = ephemeral_jwt
|
||||||
|
else:
|
||||||
|
request.state.jwt_token = None
|
||||||
else:
|
else:
|
||||||
request.state.user = None
|
request.state.user = None
|
||||||
request.state.api_key_id = None
|
request.state.api_key_id = None
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ Uses contextvars to safely pass user auth info through async calls.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any, List
|
||||||
|
|
||||||
# Context variables for current request authentication
|
# Context variables for current request authentication
|
||||||
_current_user_id: ContextVar[Optional[str]] = ContextVar(
|
_current_user_id: ContextVar[Optional[str]] = ContextVar(
|
||||||
|
|
@ -13,6 +13,12 @@ _current_user_id: ContextVar[Optional[str]] = ContextVar(
|
||||||
_current_jwt_token: ContextVar[Optional[str]] = ContextVar(
|
_current_jwt_token: ContextVar[Optional[str]] = ContextVar(
|
||||||
"current_jwt_token", default=None
|
"current_jwt_token", default=None
|
||||||
)
|
)
|
||||||
|
_current_user_groups: ContextVar[Optional[List[str]]] = ContextVar(
|
||||||
|
"current_user_groups", default=None
|
||||||
|
)
|
||||||
|
_current_user_roles: ContextVar[Optional[List[str]]] = ContextVar(
|
||||||
|
"current_user_roles", default=None
|
||||||
|
)
|
||||||
_current_search_filters: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
|
_current_search_filters: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
|
||||||
"current_search_filters", default=None
|
"current_search_filters", default=None
|
||||||
)
|
)
|
||||||
|
|
@ -24,10 +30,24 @@ _current_score_threshold: ContextVar[Optional[float]] = ContextVar(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_auth_context(user_id: str, jwt_token: str):
|
def set_auth_context(
|
||||||
"""Set authentication context for the current async context"""
|
user_id: str,
|
||||||
|
jwt_token: str,
|
||||||
|
groups: Optional[List[str]] = None,
|
||||||
|
roles: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
"""Set authentication context for the current async context
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
jwt_token: The JWT token for authentication
|
||||||
|
groups: Optional list of groups the user belongs to (for RBAC)
|
||||||
|
roles: Optional list of roles the user has (for RBAC)
|
||||||
|
"""
|
||||||
_current_user_id.set(user_id)
|
_current_user_id.set(user_id)
|
||||||
_current_jwt_token.set(jwt_token)
|
_current_jwt_token.set(jwt_token)
|
||||||
|
_current_user_groups.set(groups or [])
|
||||||
|
_current_user_roles.set(roles or [])
|
||||||
|
|
||||||
|
|
||||||
def get_current_user_id() -> Optional[str]:
|
def get_current_user_id() -> Optional[str]:
|
||||||
|
|
@ -40,6 +60,16 @@ def get_current_jwt_token() -> Optional[str]:
|
||||||
return _current_jwt_token.get()
|
return _current_jwt_token.get()
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_groups() -> List[str]:
|
||||||
|
"""Get current user's groups from context (for RBAC)"""
|
||||||
|
return _current_user_groups.get() or []
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_roles() -> List[str]:
|
||||||
|
"""Get current user's roles from context (for RBAC)"""
|
||||||
|
return _current_user_roles.get() or []
|
||||||
|
|
||||||
|
|
||||||
def get_auth_context() -> tuple[Optional[str], Optional[str]]:
|
def get_auth_context() -> tuple[Optional[str], Optional[str]]:
|
||||||
"""Get current authentication context (user_id, jwt_token)"""
|
"""Get current authentication context (user_id, jwt_token)"""
|
||||||
return _current_user_id.get(), _current_jwt_token.get()
|
return _current_user_id.get(), _current_jwt_token.get()
|
||||||
|
|
|
||||||
30
src/main.py
30
src/main.py
|
|
@ -1314,7 +1314,7 @@ async def create_app():
|
||||||
# Chat endpoints
|
# Chat endpoints
|
||||||
Route(
|
Route(
|
||||||
"/v1/chat",
|
"/v1/chat",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_chat.chat_create_endpoint,
|
v1_chat.chat_create_endpoint,
|
||||||
chat_service=services["chat_service"],
|
chat_service=services["chat_service"],
|
||||||
|
|
@ -1325,7 +1325,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/chat",
|
"/v1/chat",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_chat.chat_list_endpoint,
|
v1_chat.chat_list_endpoint,
|
||||||
chat_service=services["chat_service"],
|
chat_service=services["chat_service"],
|
||||||
|
|
@ -1336,7 +1336,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/chat/{chat_id}",
|
"/v1/chat/{chat_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_chat.chat_get_endpoint,
|
v1_chat.chat_get_endpoint,
|
||||||
chat_service=services["chat_service"],
|
chat_service=services["chat_service"],
|
||||||
|
|
@ -1347,7 +1347,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/chat/{chat_id}",
|
"/v1/chat/{chat_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_chat.chat_delete_endpoint,
|
v1_chat.chat_delete_endpoint,
|
||||||
chat_service=services["chat_service"],
|
chat_service=services["chat_service"],
|
||||||
|
|
@ -1359,7 +1359,7 @@ async def create_app():
|
||||||
# Search endpoint
|
# Search endpoint
|
||||||
Route(
|
Route(
|
||||||
"/v1/search",
|
"/v1/search",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_search.search_endpoint,
|
v1_search.search_endpoint,
|
||||||
search_service=services["search_service"],
|
search_service=services["search_service"],
|
||||||
|
|
@ -1371,7 +1371,7 @@ async def create_app():
|
||||||
# Documents endpoints
|
# Documents endpoints
|
||||||
Route(
|
Route(
|
||||||
"/v1/documents/ingest",
|
"/v1/documents/ingest",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_documents.ingest_endpoint,
|
v1_documents.ingest_endpoint,
|
||||||
document_service=services["document_service"],
|
document_service=services["document_service"],
|
||||||
|
|
@ -1384,7 +1384,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/tasks/{task_id}",
|
"/v1/tasks/{task_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_documents.task_status_endpoint,
|
v1_documents.task_status_endpoint,
|
||||||
task_service=services["task_service"],
|
task_service=services["task_service"],
|
||||||
|
|
@ -1395,7 +1395,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/documents",
|
"/v1/documents",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_documents.delete_document_endpoint,
|
v1_documents.delete_document_endpoint,
|
||||||
document_service=services["document_service"],
|
document_service=services["document_service"],
|
||||||
|
|
@ -1407,14 +1407,14 @@ async def create_app():
|
||||||
# Settings endpoints
|
# Settings endpoints
|
||||||
Route(
|
Route(
|
||||||
"/v1/settings",
|
"/v1/settings",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(v1_settings.get_settings_endpoint)
|
partial(v1_settings.get_settings_endpoint)
|
||||||
),
|
),
|
||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/settings",
|
"/v1/settings",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_settings.update_settings_endpoint,
|
v1_settings.update_settings_endpoint,
|
||||||
session_manager=services["session_manager"],
|
session_manager=services["session_manager"],
|
||||||
|
|
@ -1425,7 +1425,7 @@ async def create_app():
|
||||||
# Knowledge filters endpoints
|
# Knowledge filters endpoints
|
||||||
Route(
|
Route(
|
||||||
"/v1/knowledge-filters",
|
"/v1/knowledge-filters",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_knowledge_filters.create_endpoint,
|
v1_knowledge_filters.create_endpoint,
|
||||||
knowledge_filter_service=services["knowledge_filter_service"],
|
knowledge_filter_service=services["knowledge_filter_service"],
|
||||||
|
|
@ -1436,7 +1436,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/knowledge-filters/search",
|
"/v1/knowledge-filters/search",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_knowledge_filters.search_endpoint,
|
v1_knowledge_filters.search_endpoint,
|
||||||
knowledge_filter_service=services["knowledge_filter_service"],
|
knowledge_filter_service=services["knowledge_filter_service"],
|
||||||
|
|
@ -1447,7 +1447,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/knowledge-filters/{filter_id}",
|
"/v1/knowledge-filters/{filter_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_knowledge_filters.get_endpoint,
|
v1_knowledge_filters.get_endpoint,
|
||||||
knowledge_filter_service=services["knowledge_filter_service"],
|
knowledge_filter_service=services["knowledge_filter_service"],
|
||||||
|
|
@ -1458,7 +1458,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/knowledge-filters/{filter_id}",
|
"/v1/knowledge-filters/{filter_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_knowledge_filters.update_endpoint,
|
v1_knowledge_filters.update_endpoint,
|
||||||
knowledge_filter_service=services["knowledge_filter_service"],
|
knowledge_filter_service=services["knowledge_filter_service"],
|
||||||
|
|
@ -1469,7 +1469,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/knowledge-filters/{filter_id}",
|
"/v1/knowledge-filters/{filter_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_knowledge_filters.delete_endpoint,
|
v1_knowledge_filters.delete_endpoint,
|
||||||
knowledge_filter_service=services["knowledge_filter_service"],
|
knowledge_filter_service=services["knowledge_filter_service"],
|
||||||
|
|
|
||||||
|
|
@ -158,6 +158,7 @@ class TaskProcessor:
|
||||||
connector_type: str = "local",
|
connector_type: str = "local",
|
||||||
embedding_model: str = None,
|
embedding_model: str = None,
|
||||||
is_sample_data: bool = False,
|
is_sample_data: bool = False,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Standard processing pipeline for non-Langflow processors:
|
Standard processing pipeline for non-Langflow processors:
|
||||||
|
|
@ -166,6 +167,8 @@ class TaskProcessor:
|
||||||
Args:
|
Args:
|
||||||
embedding_model: Embedding model to use (defaults to the current
|
embedding_model: Embedding model to use (defaults to the current
|
||||||
embedding model from settings)
|
embedding model from settings)
|
||||||
|
allowed_groups: List of groups that can access this document (RBAC).
|
||||||
|
Empty list or None means no group restrictions.
|
||||||
"""
|
"""
|
||||||
import datetime
|
import datetime
|
||||||
from config.settings import INDEX_NAME, clients, get_embedding_model
|
from config.settings import INDEX_NAME, clients, get_embedding_model
|
||||||
|
|
@ -259,6 +262,11 @@ class TaskProcessor:
|
||||||
if owner_email is not None:
|
if owner_email is not None:
|
||||||
chunk_doc["owner_email"] = owner_email
|
chunk_doc["owner_email"] = owner_email
|
||||||
|
|
||||||
|
# RBAC: Set allowed groups for access control
|
||||||
|
# If allowed_groups is provided and non-empty, store it for DLS filtering
|
||||||
|
if allowed_groups:
|
||||||
|
chunk_doc["allowed_groups"] = allowed_groups
|
||||||
|
|
||||||
# Mark as sample data if specified
|
# Mark as sample data if specified
|
||||||
if is_sample_data:
|
if is_sample_data:
|
||||||
chunk_doc["is_sample_data"] = "true"
|
chunk_doc["is_sample_data"] = "true"
|
||||||
|
|
@ -309,6 +317,7 @@ class DocumentFileProcessor(TaskProcessor):
|
||||||
owner_name: str = None,
|
owner_name: str = None,
|
||||||
owner_email: str = None,
|
owner_email: str = None,
|
||||||
is_sample_data: bool = False,
|
is_sample_data: bool = False,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
super().__init__(document_service)
|
super().__init__(document_service)
|
||||||
self.owner_user_id = owner_user_id
|
self.owner_user_id = owner_user_id
|
||||||
|
|
@ -316,6 +325,7 @@ class DocumentFileProcessor(TaskProcessor):
|
||||||
self.owner_name = owner_name
|
self.owner_name = owner_name
|
||||||
self.owner_email = owner_email
|
self.owner_email = owner_email
|
||||||
self.is_sample_data = is_sample_data
|
self.is_sample_data = is_sample_data
|
||||||
|
self.allowed_groups = allowed_groups
|
||||||
|
|
||||||
async def process_item(
|
async def process_item(
|
||||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||||
|
|
@ -351,6 +361,7 @@ class DocumentFileProcessor(TaskProcessor):
|
||||||
file_size=file_size,
|
file_size=file_size,
|
||||||
connector_type="local",
|
connector_type="local",
|
||||||
is_sample_data=self.is_sample_data,
|
is_sample_data=self.is_sample_data,
|
||||||
|
allowed_groups=self.allowed_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
file_task.status = TaskStatus.COMPLETED
|
file_task.status = TaskStatus.COMPLETED
|
||||||
|
|
@ -382,6 +393,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
owner_name: str = None,
|
owner_name: str = None,
|
||||||
owner_email: str = None,
|
owner_email: str = None,
|
||||||
document_service=None,
|
document_service=None,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
super().__init__(document_service=document_service)
|
super().__init__(document_service=document_service)
|
||||||
self.connector_service = connector_service
|
self.connector_service = connector_service
|
||||||
|
|
@ -391,6 +403,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
self.jwt_token = jwt_token
|
self.jwt_token = jwt_token
|
||||||
self.owner_name = owner_name
|
self.owner_name = owner_name
|
||||||
self.owner_email = owner_email
|
self.owner_email = owner_email
|
||||||
|
self.allowed_groups = allowed_groups
|
||||||
|
|
||||||
async def process_item(
|
async def process_item(
|
||||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||||
|
|
@ -445,6 +458,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
owner_email=self.owner_email,
|
owner_email=self.owner_email,
|
||||||
file_size=len(document.content),
|
file_size=len(document.content),
|
||||||
connector_type=connection.connector_type,
|
connector_type=connection.connector_type,
|
||||||
|
allowed_groups=self.allowed_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add connector-specific metadata
|
# Add connector-specific metadata
|
||||||
|
|
@ -478,6 +492,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
|
||||||
jwt_token: str = None,
|
jwt_token: str = None,
|
||||||
owner_name: str = None,
|
owner_name: str = None,
|
||||||
owner_email: str = None,
|
owner_email: str = None,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.langflow_connector_service = langflow_connector_service
|
self.langflow_connector_service = langflow_connector_service
|
||||||
|
|
@ -487,6 +502,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
|
||||||
self.jwt_token = jwt_token
|
self.jwt_token = jwt_token
|
||||||
self.owner_name = owner_name
|
self.owner_name = owner_name
|
||||||
self.owner_email = owner_email
|
self.owner_email = owner_email
|
||||||
|
self.allowed_groups = allowed_groups
|
||||||
|
|
||||||
async def process_item(
|
async def process_item(
|
||||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||||
|
|
@ -580,6 +596,7 @@ class S3FileProcessor(TaskProcessor):
|
||||||
jwt_token: str = None,
|
jwt_token: str = None,
|
||||||
owner_name: str = None,
|
owner_name: str = None,
|
||||||
owner_email: str = None,
|
owner_email: str = None,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
|
|
@ -590,6 +607,7 @@ class S3FileProcessor(TaskProcessor):
|
||||||
self.jwt_token = jwt_token
|
self.jwt_token = jwt_token
|
||||||
self.owner_name = owner_name
|
self.owner_name = owner_name
|
||||||
self.owner_email = owner_email
|
self.owner_email = owner_email
|
||||||
|
self.allowed_groups = allowed_groups
|
||||||
|
|
||||||
async def process_item(
|
async def process_item(
|
||||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||||
|
|
@ -638,6 +656,7 @@ class S3FileProcessor(TaskProcessor):
|
||||||
owner_email=self.owner_email,
|
owner_email=self.owner_email,
|
||||||
file_size=file_size,
|
file_size=file_size,
|
||||||
connector_type="s3",
|
connector_type="s3",
|
||||||
|
allowed_groups=self.allowed_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
result["path"] = f"s3://{self.bucket}/{item}"
|
result["path"] = f"s3://{self.bucket}/{item}"
|
||||||
|
|
@ -669,6 +688,7 @@ class LangflowFileProcessor(TaskProcessor):
|
||||||
settings: dict = None,
|
settings: dict = None,
|
||||||
delete_after_ingest: bool = True,
|
delete_after_ingest: bool = True,
|
||||||
replace_duplicates: bool = False,
|
replace_duplicates: bool = False,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.langflow_file_service = langflow_file_service
|
self.langflow_file_service = langflow_file_service
|
||||||
|
|
@ -682,6 +702,7 @@ class LangflowFileProcessor(TaskProcessor):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
self.delete_after_ingest = delete_after_ingest
|
self.delete_after_ingest = delete_after_ingest
|
||||||
self.replace_duplicates = replace_duplicates
|
self.replace_duplicates = replace_duplicates
|
||||||
|
self.allowed_groups = allowed_groups
|
||||||
|
|
||||||
async def process_item(
|
async def process_item(
|
||||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||||
|
|
@ -765,6 +786,10 @@ class LangflowFileProcessor(TaskProcessor):
|
||||||
metadata_tweaks.append({"key": "owner_email", "value": self.owner_email})
|
metadata_tweaks.append({"key": "owner_email", "value": self.owner_email})
|
||||||
# Mark as local upload for connector_type
|
# Mark as local upload for connector_type
|
||||||
metadata_tweaks.append({"key": "connector_type", "value": "local"})
|
metadata_tweaks.append({"key": "connector_type", "value": "local"})
|
||||||
|
# RBAC: Add allowed_groups for access control
|
||||||
|
if self.allowed_groups:
|
||||||
|
# Store as comma-separated string for Langflow metadata
|
||||||
|
metadata_tweaks.append({"key": "allowed_groups", "value": ",".join(self.allowed_groups)})
|
||||||
|
|
||||||
if metadata_tweaks:
|
if metadata_tweaks:
|
||||||
# Initialize the OpenSearch component tweaks if not already present
|
# Initialize the OpenSearch component tweaks if not already present
|
||||||
|
|
|
||||||
|
|
@ -52,15 +52,19 @@ class APIKeyService:
|
||||||
user_email: str,
|
user_email: str,
|
||||||
name: str,
|
name: str,
|
||||||
jwt_token: str = None,
|
jwt_token: str = None,
|
||||||
|
roles: List[str] = None,
|
||||||
|
groups: List[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Create a new API key for a user.
|
Create a new API key for a user with optional RBAC restrictions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: The user's ID
|
user_id: The user's ID
|
||||||
user_email: The user's email
|
user_email: The user's email
|
||||||
name: A friendly name for the key
|
name: A friendly name for the key
|
||||||
jwt_token: JWT token for OpenSearch authentication
|
jwt_token: JWT token for OpenSearch authentication
|
||||||
|
roles: Optional list of roles to restrict this key to
|
||||||
|
groups: Optional list of groups this key can access
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with success status, key info, and the full key (only shown once)
|
Dict with success status, key info, and the full key (only shown once)
|
||||||
|
|
@ -74,6 +78,14 @@ class APIKeyService:
|
||||||
|
|
||||||
now = datetime.utcnow().isoformat()
|
now = datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
# Default roles if not specified
|
||||||
|
if roles is None:
|
||||||
|
roles = ["openrag_user"]
|
||||||
|
|
||||||
|
# Default groups to empty if not specified
|
||||||
|
if groups is None:
|
||||||
|
groups = []
|
||||||
|
|
||||||
# Create the document to store
|
# Create the document to store
|
||||||
key_doc = {
|
key_doc = {
|
||||||
"key_id": key_id,
|
"key_id": key_id,
|
||||||
|
|
@ -85,6 +97,9 @@ class APIKeyService:
|
||||||
"created_at": now,
|
"created_at": now,
|
||||||
"last_used_at": None,
|
"last_used_at": None,
|
||||||
"revoked": False,
|
"revoked": False,
|
||||||
|
# RBAC fields
|
||||||
|
"roles": roles,
|
||||||
|
"groups": groups,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get OpenSearch client
|
# Get OpenSearch client
|
||||||
|
|
@ -105,6 +120,8 @@ class APIKeyService:
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
key_id=key_id,
|
key_id=key_id,
|
||||||
key_prefix=key_prefix,
|
key_prefix=key_prefix,
|
||||||
|
roles=roles,
|
||||||
|
groups=groups,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
|
|
@ -112,6 +129,8 @@ class APIKeyService:
|
||||||
"key_prefix": key_prefix,
|
"key_prefix": key_prefix,
|
||||||
"name": name,
|
"name": name,
|
||||||
"created_at": now,
|
"created_at": now,
|
||||||
|
"roles": roles,
|
||||||
|
"groups": groups,
|
||||||
"api_key": full_key, # Only returned once!
|
"api_key": full_key, # Only returned once!
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
|
@ -123,13 +142,13 @@ class APIKeyService:
|
||||||
|
|
||||||
async def validate_key(self, api_key: str) -> Optional[Dict[str, Any]]:
|
async def validate_key(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Validate an API key and return user info if valid.
|
Validate an API key and return user info with RBAC claims if valid.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: The full API key to validate
|
api_key: The full API key to validate
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with user info if valid, None if invalid
|
Dict with user info including roles and groups if valid, None if invalid
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Check key format
|
# Check key format
|
||||||
|
|
@ -181,11 +200,15 @@ class APIKeyService:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Don't fail validation if update fails
|
pass # Don't fail validation if update fails
|
||||||
|
|
||||||
|
# Return user info with RBAC claims
|
||||||
return {
|
return {
|
||||||
"key_id": key_doc["key_id"],
|
"key_id": key_doc["key_id"],
|
||||||
"user_id": key_doc["user_id"],
|
"user_id": key_doc["user_id"],
|
||||||
"user_email": key_doc["user_email"],
|
"user_email": key_doc["user_email"],
|
||||||
"name": key_doc["name"],
|
"name": key_doc["name"],
|
||||||
|
# RBAC fields - provide defaults for backward compatibility
|
||||||
|
"roles": key_doc.get("roles", ["openrag_user"]),
|
||||||
|
"groups": key_doc.get("groups", []),
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -225,6 +248,8 @@ class APIKeyService:
|
||||||
"created_at",
|
"created_at",
|
||||||
"last_used_at",
|
"last_used_at",
|
||||||
"revoked",
|
"revoked",
|
||||||
|
"roles",
|
||||||
|
"groups",
|
||||||
],
|
],
|
||||||
"size": 100,
|
"size": 100,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,8 @@ class ChatService:
|
||||||
previous_response_id: str = None,
|
previous_response_id: str = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
filter_id: str = None,
|
filter_id: str = None,
|
||||||
|
groups: list = None,
|
||||||
|
roles: list = None,
|
||||||
):
|
):
|
||||||
"""Handle chat requests using the patched OpenAI client"""
|
"""Handle chat requests using the patched OpenAI client"""
|
||||||
if not prompt:
|
if not prompt:
|
||||||
|
|
@ -23,7 +25,7 @@ class ChatService:
|
||||||
|
|
||||||
# Set authentication context for this request so tools can access it
|
# Set authentication context for this request so tools can access it
|
||||||
if user_id and jwt_token:
|
if user_id and jwt_token:
|
||||||
set_auth_context(user_id, jwt_token)
|
set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return async_chat_stream(
|
return async_chat_stream(
|
||||||
|
|
@ -54,6 +56,8 @@ class ChatService:
|
||||||
previous_response_id: str = None,
|
previous_response_id: str = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
filter_id: str = None,
|
filter_id: str = None,
|
||||||
|
groups: list = None,
|
||||||
|
roles: list = None,
|
||||||
):
|
):
|
||||||
"""Handle Langflow chat requests"""
|
"""Handle Langflow chat requests"""
|
||||||
if not prompt:
|
if not prompt:
|
||||||
|
|
@ -346,9 +350,9 @@ class ChatService:
|
||||||
previous_response_id=previous_response_id,
|
previous_response_id=previous_response_id,
|
||||||
)
|
)
|
||||||
else: # chat
|
else: # chat
|
||||||
# Set auth context for chat tools and provide user_id
|
# Set auth context for chat tools and provide user_id with RBAC
|
||||||
if user_id and jwt_token:
|
if user_id and jwt_token:
|
||||||
set_auth_context(user_id, jwt_token)
|
set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
|
||||||
response_text, response_id = await async_chat(
|
response_text, response_id = await async_chat(
|
||||||
clients.patched_llm_client,
|
clients.patched_llm_client,
|
||||||
document_prompt,
|
document_prompt,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import copy
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from agentd.tool_decorator import tool
|
from agentd.tool_decorator import tool
|
||||||
from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model, WATSONX_EMBEDDING_DIMENSIONS
|
from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model, WATSONX_EMBEDDING_DIMENSIONS
|
||||||
from auth_context import get_auth_context
|
from auth_context import get_auth_context, get_current_user_groups
|
||||||
from utils.logging_config import get_logger
|
from utils.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
@ -259,8 +259,24 @@ class SearchService:
|
||||||
# Build query body
|
# Build query body
|
||||||
if is_wildcard_match_all:
|
if is_wildcard_match_all:
|
||||||
# Match all documents; still allow filters to narrow scope
|
# Match all documents; still allow filters to narrow scope
|
||||||
if filter_clauses:
|
# Also add RBAC group filter for wildcard queries
|
||||||
query_block = {"bool": {"filter": filter_clauses}}
|
wildcard_filters = list(filter_clauses) # Copy existing filters
|
||||||
|
|
||||||
|
user_groups = get_current_user_groups()
|
||||||
|
if user_groups:
|
||||||
|
groups_access_filter = {
|
||||||
|
"bool": {
|
||||||
|
"should": [
|
||||||
|
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
|
||||||
|
{"terms": {"allowed_groups": user_groups}},
|
||||||
|
],
|
||||||
|
"minimum_should_match": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wildcard_filters.append(groups_access_filter)
|
||||||
|
|
||||||
|
if wildcard_filters:
|
||||||
|
query_block = {"bool": {"filter": wildcard_filters}}
|
||||||
else:
|
else:
|
||||||
query_block = {"match_all": {}}
|
query_block = {"match_all": {}}
|
||||||
else:
|
else:
|
||||||
|
|
@ -292,6 +308,29 @@ class SearchService:
|
||||||
# Add exists filter to existing filters
|
# Add exists filter to existing filters
|
||||||
all_filters = [*filter_clauses, exists_any_embedding]
|
all_filters = [*filter_clauses, exists_any_embedding]
|
||||||
|
|
||||||
|
# RBAC: Add group-based access control filter (fallback if DLS isn't configured)
|
||||||
|
# Documents are accessible if:
|
||||||
|
# 1. No allowed_groups field exists (backward compatibility, open access)
|
||||||
|
# 2. User's groups match any of the document's allowed_groups
|
||||||
|
user_groups = get_current_user_groups()
|
||||||
|
if user_groups:
|
||||||
|
groups_access_filter = {
|
||||||
|
"bool": {
|
||||||
|
"should": [
|
||||||
|
# Document has no allowed_groups restriction
|
||||||
|
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
|
||||||
|
# User's groups match document's allowed_groups
|
||||||
|
{"terms": {"allowed_groups": user_groups}},
|
||||||
|
],
|
||||||
|
"minimum_should_match": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
all_filters.append(groups_access_filter)
|
||||||
|
logger.debug(
|
||||||
|
"Added RBAC group filter",
|
||||||
|
user_groups=user_groups,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Building hybrid query with filters",
|
"Building hybrid query with filters",
|
||||||
user_filters_count=len(filter_clauses),
|
user_filters_count=len(filter_clauses),
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,8 @@ import json
|
||||||
import jwt
|
import jwt
|
||||||
import httpx
|
import httpx
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Optional, Any
|
from typing import Dict, Optional, Any, List
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, field
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
import os
|
import os
|
||||||
from utils.logging_config import get_logger
|
from utils.logging_config import get_logger
|
||||||
|
|
@ -24,12 +24,19 @@ class User:
|
||||||
provider: str = "google"
|
provider: str = "google"
|
||||||
created_at: datetime = None
|
created_at: datetime = None
|
||||||
last_login: datetime = None
|
last_login: datetime = None
|
||||||
|
# RBAC fields
|
||||||
|
roles: List[str] = field(default_factory=lambda: ["openrag_user"])
|
||||||
|
groups: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.created_at is None:
|
if self.created_at is None:
|
||||||
self.created_at = datetime.now()
|
self.created_at = datetime.now()
|
||||||
if self.last_login is None:
|
if self.last_login is None:
|
||||||
self.last_login = datetime.now()
|
self.last_login = datetime.now()
|
||||||
|
if self.roles is None:
|
||||||
|
self.roles = ["openrag_user"]
|
||||||
|
if self.groups is None:
|
||||||
|
self.groups = []
|
||||||
|
|
||||||
class AnonymousUser(User):
|
class AnonymousUser(User):
|
||||||
"""Anonymous user"""
|
"""Anonymous user"""
|
||||||
|
|
@ -136,11 +143,31 @@ class SessionManager:
|
||||||
# Create JWT token using the shared method
|
# Create JWT token using the shared method
|
||||||
return self.create_jwt_token(user)
|
return self.create_jwt_token(user)
|
||||||
|
|
||||||
def create_jwt_token(self, user: User) -> str:
|
def create_jwt_token(
|
||||||
"""Create JWT token for an existing user"""
|
self,
|
||||||
|
user: User,
|
||||||
|
roles: Optional[List[str]] = None,
|
||||||
|
groups: Optional[List[str]] = None,
|
||||||
|
expiration_days: int = 7,
|
||||||
|
) -> str:
|
||||||
|
"""Create JWT token for an existing user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: The User object to create a token for
|
||||||
|
roles: Optional roles override (for restricted API key tokens)
|
||||||
|
groups: Optional groups override (for restricted API key tokens)
|
||||||
|
expiration_days: Token expiration in days (default 7)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Encoded JWT token string
|
||||||
|
"""
|
||||||
# Use OpenSearch-compatible issuer for OIDC validation
|
# Use OpenSearch-compatible issuer for OIDC validation
|
||||||
oidc_issuer = "http://openrag-backend:8000"
|
oidc_issuer = "http://openrag-backend:8000"
|
||||||
|
|
||||||
|
# Use provided roles/groups or fall back to user's defaults
|
||||||
|
effective_roles = roles if roles is not None else user.roles
|
||||||
|
effective_groups = groups if groups is not None else user.groups
|
||||||
|
|
||||||
# Create JWT token with OIDC-compliant claims
|
# Create JWT token with OIDC-compliant claims
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
token_payload = {
|
token_payload = {
|
||||||
|
|
@ -148,7 +175,7 @@ class SessionManager:
|
||||||
"iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC
|
"iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC
|
||||||
"sub": user.user_id, # Subject (user ID)
|
"sub": user.user_id, # Subject (user ID)
|
||||||
"aud": ["opensearch", "openrag"], # Audience
|
"aud": ["opensearch", "openrag"], # Audience
|
||||||
"exp": now + timedelta(days=7), # Expiration
|
"exp": now + timedelta(days=expiration_days), # Expiration
|
||||||
"iat": now, # Issued at
|
"iat": now, # Issued at
|
||||||
"auth_time": int(now.timestamp()), # Authentication time
|
"auth_time": int(now.timestamp()), # Authentication time
|
||||||
# Custom claims
|
# Custom claims
|
||||||
|
|
@ -157,7 +184,9 @@ class SessionManager:
|
||||||
"name": user.name,
|
"name": user.name,
|
||||||
"preferred_username": user.email,
|
"preferred_username": user.email,
|
||||||
"email_verified": True,
|
"email_verified": True,
|
||||||
"roles": ["openrag_user"], # Backend role for OpenSearch
|
# RBAC claims
|
||||||
|
"roles": effective_roles, # Backend roles for OpenSearch/tools
|
||||||
|
"groups": effective_groups, # Group-based access control
|
||||||
}
|
}
|
||||||
|
|
||||||
token = jwt.encode(token_payload, self.private_key, algorithm="RS256")
|
token = jwt.encode(token_payload, self.private_key, algorithm="RS256")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue