Compare commits

...
Sign in to create a new pull request.

2 commits
main ... RBAC

Author SHA1 Message Date
Edwin Jose
6faa77d5c7 add user groups 2025-12-26 17:07:20 -05:00
Edwin Jose
ce51628db2 RBAC basic implementation 2025-12-26 16:41:39 -05:00
21 changed files with 1139 additions and 98 deletions

View file

@ -6,6 +6,8 @@ import {
export interface CreateApiKeyRequest {
name: string;
roles?: string[];
groups?: string[];
}
export interface CreateApiKeyResponse {
@ -14,6 +16,8 @@ export interface CreateApiKeyResponse {
name: string;
key_prefix: string;
created_at: string;
roles?: string[];
groups?: string[];
}
export const useCreateApiKeyMutation = (

View file

@ -0,0 +1,61 @@
import {
type UseMutationOptions,
useMutation,
useQueryClient,
} from "@tanstack/react-query";
export interface CreateGroupRequest {
name: string;
description?: string;
}
export interface CreateGroupResponse {
success: boolean;
group_id: string;
name: string;
description: string;
created_at: string;
error?: string;
}
export const useCreateGroupMutation = (
options?: Omit<
UseMutationOptions<CreateGroupResponse, Error, CreateGroupRequest>,
"mutationFn"
>,
) => {
const queryClient = useQueryClient();
async function createGroup(
variables: CreateGroupRequest,
): Promise<CreateGroupResponse> {
const response = await fetch("/api/groups", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(variables),
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.error || "Failed to create group");
}
return data;
}
return useMutation({
mutationFn: createGroup,
onSuccess: (...args) => {
queryClient.invalidateQueries({
queryKey: ["groups"],
});
options?.onSuccess?.(...args);
},
onError: options?.onError,
onSettled: options?.onSettled,
});
};

View file

@ -0,0 +1,52 @@
import {
type UseMutationOptions,
useMutation,
useQueryClient,
} from "@tanstack/react-query";
export interface DeleteGroupRequest {
group_id: string;
}
export interface DeleteGroupResponse {
success: boolean;
error?: string;
}
export const useDeleteGroupMutation = (
options?: Omit<
UseMutationOptions<DeleteGroupResponse, Error, DeleteGroupRequest>,
"mutationFn"
>,
) => {
const queryClient = useQueryClient();
async function deleteGroup(
variables: DeleteGroupRequest,
): Promise<DeleteGroupResponse> {
const response = await fetch(`/api/groups/${variables.group_id}`, {
method: "DELETE",
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.error || "Failed to delete group");
}
return data;
}
return useMutation({
mutationFn: deleteGroup,
onSuccess: (...args) => {
queryClient.invalidateQueries({
queryKey: ["groups"],
});
options?.onSuccess?.(...args);
},
onError: options?.onError,
onSettled: options?.onSettled,
});
};

View file

@ -6,6 +6,8 @@ export interface ApiKey {
key_prefix: string;
created_at: string;
last_used_at: string | null;
roles?: string[];
groups?: string[];
}
export interface GetApiKeysResponse {

View file

@ -0,0 +1,32 @@
import { type UseQueryOptions, useQuery } from "@tanstack/react-query";
export interface Group {
group_id: string;
name: string;
description: string;
created_at: string;
}
export interface GetGroupsResponse {
success: boolean;
groups: Group[];
}
export const useGetGroupsQuery = (
options?: Omit<UseQueryOptions<GetGroupsResponse>, "queryKey" | "queryFn">,
) => {
async function getGroups(): Promise<GetGroupsResponse> {
const response = await fetch("/api/groups");
if (response.ok) {
return await response.json();
}
throw new Error("Failed to fetch groups");
}
return useQuery({
queryKey: ["groups"],
queryFn: getGroups,
...options,
});
};

View file

@ -12,8 +12,11 @@ import {
useGetOpenAIModelsQuery,
} from "@/app/api/queries/useGetModelsQuery";
import { useGetApiKeysQuery } from "@/app/api/queries/useGetApiKeysQuery";
import { useGetGroupsQuery } from "@/app/api/queries/useGetGroupsQuery";
import { useCreateApiKeyMutation } from "@/app/api/mutations/useCreateApiKeyMutation";
import { useRevokeApiKeyMutation } from "@/app/api/mutations/useRevokeApiKeyMutation";
import { ManageGroupsModal } from "@/components/ManageGroupsModal";
import { MultiSelect } from "@/components/ui/multi-select";
import { useGetSettingsQuery } from "@/app/api/queries/useGetSettingsQuery";
import { ConfirmationDialog } from "@/components/confirmation-dialog";
import {
@ -136,8 +139,10 @@ function KnowledgeSourcesPage() {
// API Keys state
const [createKeyDialogOpen, setCreateKeyDialogOpen] = useState(false);
const [newKeyName, setNewKeyName] = useState("");
const [newKeyGroups, setNewKeyGroups] = useState<string[]>([]);
const [newlyCreatedKey, setNewlyCreatedKey] = useState<string | null>(null);
const [showKeyDialogOpen, setShowKeyDialogOpen] = useState(false);
const [manageGroupsOpen, setManageGroupsOpen] = useState(false);
// Fetch settings using React Query
const { data: settings = {} } = useGetSettingsQuery({
@ -149,6 +154,11 @@ function KnowledgeSourcesPage() {
enabled: isAuthenticated || isNoAuthMode,
});
// Fetch user groups
const { data: groupsData } = useGetGroupsQuery({
enabled: isAuthenticated || isNoAuthMode,
});
// API key mutations
const createApiKeyMutation = useCreateApiKeyMutation({
onSuccess: (data) => {
@ -156,6 +166,7 @@ function KnowledgeSourcesPage() {
setCreateKeyDialogOpen(false);
setShowKeyDialogOpen(true);
setNewKeyName("");
setNewKeyGroups([]);
toast.success("API key created");
},
onError: (error) => {
@ -438,7 +449,11 @@ function KnowledgeSourcesPage() {
toast.error("Please enter a name for the API key");
return;
}
createApiKeyMutation.mutate({ name: newKeyName.trim() });
// Use selected groups directly (already an array)
createApiKeyMutation.mutate({
name: newKeyName.trim(),
groups: newKeyGroups.length > 0 ? newKeyGroups : undefined,
});
};
const handleRevokeApiKey = (keyId: string) => {
@ -1426,6 +1441,9 @@ function KnowledgeSourcesPage() {
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
Key
</th>
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
Groups
</th>
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
Created
</th>
@ -1448,6 +1466,24 @@ function KnowledgeSourcesPage() {
{key.key_prefix}...
</code>
</td>
<td className="px-4 py-3 text-sm text-muted-foreground">
{key.groups && key.groups.length > 0 ? (
<div className="flex flex-wrap gap-1">
{key.groups.map((group: string) => (
<span
key={group}
className="inline-flex items-center px-2 py-0.5 rounded-full text-xs bg-primary/10 text-primary"
>
{group}
</span>
))}
</div>
) : (
<span className="text-muted-foreground/50">
All groups
</span>
)}
</td>
<td className="px-4 py-3 text-sm text-muted-foreground">
{formatDate(key.created_at)}
</td>
@ -1507,58 +1543,88 @@ function KnowledgeSourcesPage() {
</Card>
)}
{/* Create API Key Dialog */}
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>Create API Key</DialogTitle>
<DialogDescription>
Give your API key a name to help you identify it later.
</DialogDescription>
</DialogHeader>
<div className="py-4">
<LabelWrapper label="Name" id="api-key-name">
<Input
id="api-key-name"
placeholder="e.g., Production App, Development"
value={newKeyName}
onChange={(e) => setNewKeyName(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
handleCreateApiKey();
}
}}
/>
</LabelWrapper>
</div>
<DialogFooter>
<Button
variant="ghost"
onClick={() => {
setCreateKeyDialogOpen(false);
setNewKeyName("");
{/* Create API Key Dialog */}
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>Create API Key</DialogTitle>
<DialogDescription>
Create an API key with optional group restrictions for access
control.
</DialogDescription>
</DialogHeader>
<div className="py-4 space-y-4">
<LabelWrapper label="Name" id="api-key-name">
<Input
id="api-key-name"
placeholder="e.g., Production App, Development"
value={newKeyName}
onChange={(e) => setNewKeyName(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
handleCreateApiKey();
}
}}
size="sm"
>
Cancel
</Button>
<Button
onClick={handleCreateApiKey}
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
size="sm"
>
{createApiKeyMutation.isPending ? (
<>
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
Creating...
</>
) : (
"Create Key"
)}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
/>
</LabelWrapper>
<div className="space-y-2">
<LabelWrapper
label="Groups (optional)"
id="api-key-groups"
helperText="Restrict this key to specific groups"
>
<MultiSelect
options={(groupsData?.groups || []).map((g) => ({
value: g.name,
label: g.name,
}))}
value={newKeyGroups}
onValueChange={setNewKeyGroups}
placeholder="Select groups..."
showAllOption={false}
searchPlaceholder="Search groups..."
/>
</LabelWrapper>
<Button
variant="link"
size="sm"
className="h-auto p-0 text-xs"
onClick={() => setManageGroupsOpen(true)}
>
<Plus className="h-3 w-3 mr-1" />
Manage Groups
</Button>
</div>
</div>
<DialogFooter>
<Button
variant="ghost"
onClick={() => {
setCreateKeyDialogOpen(false);
setNewKeyName("");
setNewKeyGroups([]);
}}
size="sm"
>
Cancel
</Button>
<Button
onClick={handleCreateApiKey}
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
size="sm"
>
{createApiKeyMutation.isPending ? (
<>
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
Creating...
</>
) : (
"Create Key"
)}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
{/* Show Created API Key Dialog */}
<Dialog
@ -1593,6 +1659,12 @@ function KnowledgeSourcesPage() {
</DialogFooter>
</DialogContent>
</Dialog>
{/* Manage Groups Modal */}
<ManageGroupsModal
open={manageGroupsOpen}
onOpenChange={setManageGroupsOpen}
/>
</div>
);
}

View file

@ -0,0 +1,175 @@
"use client";
import { useState } from "react";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { useGetGroupsQuery } from "@/app/api/queries/useGetGroupsQuery";
import { useCreateGroupMutation } from "@/app/api/mutations/useCreateGroupMutation";
import { useDeleteGroupMutation } from "@/app/api/mutations/useDeleteGroupMutation";
import { Plus, Trash2, Loader2, Users } from "lucide-react";
import { toast } from "sonner";
interface ManageGroupsModalProps {
open: boolean;
onOpenChange: (open: boolean) => void;
}
export function ManageGroupsModal({
open,
onOpenChange,
}: ManageGroupsModalProps) {
const [newGroupName, setNewGroupName] = useState("");
const [newGroupDescription, setNewGroupDescription] = useState("");
const { data: groupsData, isLoading: groupsLoading } = useGetGroupsQuery({
enabled: open,
});
const createGroupMutation = useCreateGroupMutation({
onSuccess: () => {
setNewGroupName("");
setNewGroupDescription("");
toast.success("Group created successfully");
},
onError: (error) => {
toast.error("Failed to create group", { description: error.message });
},
});
const deleteGroupMutation = useDeleteGroupMutation({
onSuccess: () => {
toast.success("Group deleted successfully");
},
onError: (error) => {
toast.error("Failed to delete group", { description: error.message });
},
});
const handleCreateGroup = () => {
if (!newGroupName.trim()) {
toast.error("Please enter a group name");
return;
}
createGroupMutation.mutate({
name: newGroupName.trim(),
description: newGroupDescription.trim(),
});
};
const handleDeleteGroup = (groupId: string, groupName: string) => {
if (confirm(`Are you sure you want to delete the group "${groupName}"?`)) {
deleteGroupMutation.mutate({ group_id: groupId });
}
};
const groups = groupsData?.groups || [];
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-md">
<DialogHeader>
<DialogTitle>Manage User Groups</DialogTitle>
<DialogDescription>
Create and manage user groups for access control. Groups can be
assigned to API keys to restrict document access.
</DialogDescription>
</DialogHeader>
<div className="space-y-4 py-4">
{/* Add new group section */}
<div className="space-y-2">
<label className="text-sm font-medium">Create New Group</label>
<div className="flex gap-2">
<Input
placeholder="Group name"
value={newGroupName}
onChange={(e) => setNewGroupName(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter") {
handleCreateGroup();
}
}}
className="flex-1"
/>
<Button
onClick={handleCreateGroup}
disabled={
createGroupMutation.isPending || !newGroupName.trim()
}
size="sm"
>
{createGroupMutation.isPending ? (
<Loader2 className="h-4 w-4 animate-spin" />
) : (
<Plus className="h-4 w-4" />
)}
</Button>
</div>
<Input
placeholder="Description (optional)"
value={newGroupDescription}
onChange={(e) => setNewGroupDescription(e.target.value)}
className="text-sm"
/>
</div>
{/* Existing groups list */}
<div className="space-y-2">
<label className="text-sm font-medium">Existing Groups</label>
{groupsLoading ? (
<div className="flex items-center justify-center py-4">
<Loader2 className="h-5 w-5 animate-spin text-muted-foreground" />
</div>
) : groups.length > 0 ? (
<div className="border rounded-lg divide-y max-h-48 overflow-y-auto">
{groups.map((group) => (
<div
key={group.group_id}
className="flex items-center justify-between px-3 py-2 hover:bg-muted/50"
>
<div className="flex-1 min-w-0">
<p className="text-sm font-medium truncate">
{group.name}
</p>
{group.description && (
<p className="text-xs text-muted-foreground truncate">
{group.description}
</p>
)}
</div>
<Button
variant="ghost"
size="sm"
className="h-8 w-8 p-0 text-destructive hover:text-destructive hover:bg-destructive/10"
onClick={() =>
handleDeleteGroup(group.group_id, group.name)
}
disabled={deleteGroupMutation.isPending}
>
<Trash2 className="h-4 w-4" />
</Button>
</div>
))}
</div>
) : (
<div className="text-center py-6 border rounded-lg">
<Users className="h-8 w-8 mx-auto text-muted-foreground/50 mb-2" />
<p className="text-sm text-muted-foreground">
No groups yet. Create one above.
</p>
</div>
)}
</div>
</div>
</DialogContent>
</Dialog>
);
}

View file

@ -35,6 +35,10 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
set_search_limit(limit)
set_score_threshold(score_threshold)
# Get RBAC groups and roles from user for access control
user_groups = getattr(user, "groups", [])
user_roles = getattr(user, "roles", [])
if stream:
return StreamingResponse(
await chat_service.chat(
@ -44,6 +48,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id,
stream=True,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
),
media_type="text/event-stream",
headers={
@ -61,6 +67,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id,
stream=False,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
)
return JSONResponse(result)
@ -81,6 +89,10 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
# Get RBAC groups and roles from user for access control
user_groups = getattr(user, "groups", [])
user_roles = getattr(user, "roles", [])
if not prompt:
return JSONResponse({"error": "Prompt is required"}, status_code=400)
@ -105,6 +117,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id,
stream=True,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
),
media_type="text/event-stream",
headers={
@ -122,6 +136,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
previous_response_id=previous_response_id,
stream=False,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
)
return JSONResponse(result)

114
src/api/groups.py Normal file
View file

@ -0,0 +1,114 @@
"""
User Groups management endpoints.
These endpoints allow managing user groups for RBAC.
"""
from starlette.requests import Request
from starlette.responses import JSONResponse
from utils.logging_config import get_logger
logger = get_logger(__name__)
async def list_groups_endpoint(request: Request, group_service):
"""
List all user groups.
GET /groups
Response:
{
"success": true,
"groups": [
{
"group_id": "...",
"name": "finance",
"description": "Finance team",
"created_at": "2024-01-01T00:00:00"
}
]
}
"""
result = await group_service.list_groups()
return JSONResponse(result)
async def create_group_endpoint(request: Request, group_service):
"""
Create a new user group.
POST /groups
Body: {"name": "finance", "description": "Finance team"}
Response:
{
"success": true,
"group_id": "...",
"name": "finance",
"description": "Finance team",
"created_at": "2024-01-01T00:00:00"
}
"""
try:
data = await request.json()
name = data.get("name", "").strip()
description = data.get("description", "").strip()
if not name:
return JSONResponse(
{"success": False, "error": "Name is required"},
status_code=400,
)
if len(name) > 100:
return JSONResponse(
{"success": False, "error": "Name must be 100 characters or less"},
status_code=400,
)
result = await group_service.create_group(
name=name,
description=description,
)
if result.get("success"):
return JSONResponse(result)
elif "already exists" in result.get("error", ""):
return JSONResponse(result, status_code=409)
else:
return JSONResponse(result, status_code=500)
except Exception as e:
logger.error(f"Failed to create group: {e}")
return JSONResponse(
{"success": False, "error": str(e)},
status_code=500,
)
async def delete_group_endpoint(request: Request, group_service):
"""
Delete a user group.
DELETE /groups/{group_id}
Response:
{"success": true}
"""
group_id = request.path_params.get("group_id")
if not group_id:
return JSONResponse(
{"success": False, "error": "Group ID is required"},
status_code=400,
)
result = await group_service.delete_group(group_id=group_id)
if result.get("success"):
return JSONResponse(result)
elif result.get("error") == "Group not found":
return JSONResponse(result, status_code=404)
else:
return JSONResponse(result, status_code=500)

View file

@ -42,10 +42,14 @@ async def list_keys_endpoint(request: Request, api_key_service):
async def create_key_endpoint(request: Request, api_key_service):
"""
Create a new API key for the authenticated user.
Create a new API key for the authenticated user with optional RBAC restrictions.
POST /keys
Body: {"name": "My API Key"}
Body: {
"name": "My API Key",
"roles": ["openrag_user"], // Optional: restrict key to specific roles
"groups": ["finance", "hr"] // Optional: restrict key to specific groups
}
Response:
{
@ -53,6 +57,8 @@ async def create_key_endpoint(request: Request, api_key_service):
"key_id": "...",
"key_prefix": "orag_abc12345",
"name": "My API Key",
"roles": ["openrag_user"],
"groups": ["finance", "hr"],
"created_at": "2024-01-01T00:00:00",
"api_key": "orag_abc12345..." // Full key, only shown once!
}
@ -78,11 +84,29 @@ async def create_key_endpoint(request: Request, api_key_service):
status_code=400,
)
# Extract optional RBAC fields
roles = data.get("roles")
groups = data.get("groups")
# Validate roles and groups are lists if provided
if roles is not None and not isinstance(roles, list):
return JSONResponse(
{"success": False, "error": "roles must be a list"},
status_code=400,
)
if groups is not None and not isinstance(groups, list):
return JSONResponse(
{"success": False, "error": "groups must be a list"},
status_code=400,
)
result = await api_key_service.create_key(
user_id=user_id,
user_email=user_email,
name=name,
jwt_token=jwt_token,
roles=roles,
groups=groups,
)
if result.get("success"):

View file

@ -104,14 +104,18 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
user = request.state.user
user_id = user.user_id
jwt_token = session_manager.get_effective_jwt_token(user_id, None)
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
# Get RBAC groups and roles from user for access control
user_groups = getattr(user, "groups", [])
user_roles = getattr(user, "roles", [])
# Set context variables for search tool
if filters:
set_search_filters(filters)
set_search_limit(limit)
set_score_threshold(score_threshold)
set_auth_context(user_id, jwt_token)
set_auth_context(user_id, jwt_token, groups=user_groups, roles=user_roles)
if stream:
raw_stream = await chat_service.langflow_chat(
@ -121,6 +125,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
previous_response_id=chat_id,
stream=True,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
)
chat_id_container = {}
return StreamingResponse(
@ -136,6 +142,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
previous_response_id=chat_id,
stream=False,
filter_id=filter_id,
groups=user_groups,
roles=user_roles,
)
# Transform response_id to chat_id for v1 API format
return JSONResponse({

View file

@ -1,5 +1,8 @@
"""
API Key middleware for authenticating public API requests.
This middleware validates API keys and generates ephemeral JWTs with the
key's specific roles and groups for downstream security enforcement.
"""
from starlette.requests import Request
from starlette.responses import JSONResponse
@ -32,14 +35,18 @@ def _extract_api_key(request: Request) -> str | None:
return None
def require_api_key(api_key_service):
def require_api_key(api_key_service, session_manager=None):
"""
Decorator to require API key authentication for public API endpoints.
Generates an ephemeral JWT with the API key's specific roles and groups
to enforce RBAC in downstream services (OpenSearch, tools, etc.).
Usage:
@require_api_key(api_key_service)
@require_api_key(api_key_service, session_manager)
async def my_endpoint(request):
user = request.state.user
jwt_token = request.state.jwt_token # Ephemeral restricted JWT
...
"""
@ -57,7 +64,7 @@ def require_api_key(api_key_service):
status_code=401,
)
# Validate the key
# Validate the key and get RBAC claims
user_info = await api_key_service.validate_key(api_key)
if not user_info:
@ -69,19 +76,46 @@ def require_api_key(api_key_service):
status_code=401,
)
# Create a User object from the API key info
# Extract RBAC fields from API key
key_roles = user_info.get("roles", ["openrag_user"])
key_groups = user_info.get("groups", [])
# Create a User object with the API key's roles and groups
user = User(
user_id=user_info["user_id"],
email=user_info["user_email"],
name=user_info.get("name", "API User"),
picture=None,
provider="api_key",
roles=key_roles,
groups=key_groups,
)
# Set request state
request.state.user = user
request.state.api_key_id = user_info["key_id"]
request.state.jwt_token = None # No JWT for API key auth
# Generate ephemeral JWT with the API key's restricted roles/groups
if session_manager:
# Create a short-lived JWT with the key's specific permissions
ephemeral_jwt = session_manager.create_jwt_token(
user=user,
roles=key_roles,
groups=key_groups,
expiration_days=1, # Short-lived for API requests
)
request.state.jwt_token = ephemeral_jwt
logger.debug(
"Generated ephemeral JWT for API key",
key_id=user_info["key_id"],
roles=key_roles,
groups=key_groups,
)
else:
request.state.jwt_token = None
logger.warning(
"No session_manager provided - JWT not generated for API key"
)
return await handler(request)
@ -90,10 +124,13 @@ def require_api_key(api_key_service):
return decorator
def optional_api_key(api_key_service):
def optional_api_key(api_key_service, session_manager=None):
"""
Decorator to optionally authenticate with API key.
Sets request.state.user to None if no valid API key is provided.
When a valid API key is provided, generates an ephemeral JWT with
the key's specific roles and groups.
"""
def decorator(handler):
@ -102,21 +139,38 @@ def optional_api_key(api_key_service):
api_key = _extract_api_key(request)
if api_key:
# Validate the key
# Validate the key and get RBAC claims
user_info = await api_key_service.validate_key(api_key)
if user_info:
# Create a User object from the API key info
# Extract RBAC fields from API key
key_roles = user_info.get("roles", ["openrag_user"])
key_groups = user_info.get("groups", [])
# Create a User object with the API key's roles and groups
user = User(
user_id=user_info["user_id"],
email=user_info["user_email"],
name=user_info.get("name", "API User"),
picture=None,
provider="api_key",
roles=key_roles,
groups=key_groups,
)
request.state.user = user
request.state.api_key_id = user_info["key_id"]
request.state.jwt_token = None
# Generate ephemeral JWT with the API key's restricted roles/groups
if session_manager:
ephemeral_jwt = session_manager.create_jwt_token(
user=user,
roles=key_roles,
groups=key_groups,
expiration_days=1,
)
request.state.jwt_token = ephemeral_jwt
else:
request.state.jwt_token = None
else:
request.state.user = None
request.state.api_key_id = None

View file

@ -4,7 +4,7 @@ Uses contextvars to safely pass user auth info through async calls.
"""
from contextvars import ContextVar
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, List
# Context variables for current request authentication
_current_user_id: ContextVar[Optional[str]] = ContextVar(
@ -13,6 +13,12 @@ _current_user_id: ContextVar[Optional[str]] = ContextVar(
_current_jwt_token: ContextVar[Optional[str]] = ContextVar(
"current_jwt_token", default=None
)
_current_user_groups: ContextVar[Optional[List[str]]] = ContextVar(
"current_user_groups", default=None
)
_current_user_roles: ContextVar[Optional[List[str]]] = ContextVar(
"current_user_roles", default=None
)
_current_search_filters: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
"current_search_filters", default=None
)
@ -24,10 +30,24 @@ _current_score_threshold: ContextVar[Optional[float]] = ContextVar(
)
def set_auth_context(user_id: str, jwt_token: str):
"""Set authentication context for the current async context"""
def set_auth_context(
user_id: str,
jwt_token: str,
groups: Optional[List[str]] = None,
roles: Optional[List[str]] = None,
):
"""Set authentication context for the current async context
Args:
user_id: The user's ID
jwt_token: The JWT token for authentication
groups: Optional list of groups the user belongs to (for RBAC)
roles: Optional list of roles the user has (for RBAC)
"""
_current_user_id.set(user_id)
_current_jwt_token.set(jwt_token)
_current_user_groups.set(groups or [])
_current_user_roles.set(roles or [])
def get_current_user_id() -> Optional[str]:
@ -40,6 +60,16 @@ def get_current_jwt_token() -> Optional[str]:
return _current_jwt_token.get()
def get_current_user_groups() -> List[str]:
"""Get current user's groups from context (for RBAC)"""
return _current_user_groups.get() or []
def get_current_user_roles() -> List[str]:
"""Get current user's roles from context (for RBAC)"""
return _current_user_roles.get() or []
def get_auth_context() -> tuple[Optional[str], Optional[str]]:
"""Get current authentication context (user_id, jwt_token)"""
return _current_user_id.get(), _current_jwt_token.get()

View file

@ -165,6 +165,23 @@ API_KEYS_INDEX_BODY = {
},
}
# User Groups index for RBAC management
GROUPS_INDEX_NAME = "openrag_groups"
GROUPS_INDEX_BODY = {
"settings": {
"number_of_shards": 1,
"number_of_replicas": 0,
},
"mappings": {
"properties": {
"group_id": {"type": "keyword"},
"name": {"type": "keyword"},
"description": {"type": "text"},
"created_at": {"type": "date"},
}
},
}
# Convenience base URL for Langflow REST API
LANGFLOW_BASE_URL = f"{LANGFLOW_URL}/api/v1"

View file

@ -58,6 +58,10 @@ from auth_middleware import optional_auth, require_auth
from api_key_middleware import require_api_key
from services.api_key_service import APIKeyService
from api import keys as api_keys
# User Groups management
from services.group_service import GroupService
from api import groups as api_groups
from api.v1 import chat as v1_chat, search as v1_search, documents as v1_documents, settings as v1_settings, knowledge_filters as v1_knowledge_filters
# Configuration and setup
@ -665,6 +669,9 @@ async def initialize_services():
# API Key service for public API authentication
api_key_service = APIKeyService(session_manager)
# Group service for RBAC management
group_service = GroupService(session_manager)
return {
"document_service": document_service,
"search_service": search_service,
@ -679,6 +686,7 @@ async def initialize_services():
"monitor_service": monitor_service,
"session_manager": session_manager,
"api_key_service": api_key_service,
"group_service": group_service,
}
@ -1310,11 +1318,42 @@ async def create_app():
),
methods=["DELETE"],
),
# ===== User Groups Management Endpoints (JWT auth for UI) =====
Route(
"/groups",
require_auth(services["session_manager"])(
partial(
api_groups.list_groups_endpoint,
group_service=services["group_service"],
)
),
methods=["GET"],
),
Route(
"/groups",
require_auth(services["session_manager"])(
partial(
api_groups.create_group_endpoint,
group_service=services["group_service"],
)
),
methods=["POST"],
),
Route(
"/groups/{group_id}",
require_auth(services["session_manager"])(
partial(
api_groups.delete_group_endpoint,
group_service=services["group_service"],
)
),
methods=["DELETE"],
),
# ===== Public API v1 Endpoints (API Key auth) =====
# Chat endpoints
Route(
"/v1/chat",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_chat.chat_create_endpoint,
chat_service=services["chat_service"],
@ -1325,7 +1364,7 @@ async def create_app():
),
Route(
"/v1/chat",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_chat.chat_list_endpoint,
chat_service=services["chat_service"],
@ -1336,7 +1375,7 @@ async def create_app():
),
Route(
"/v1/chat/{chat_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_chat.chat_get_endpoint,
chat_service=services["chat_service"],
@ -1347,7 +1386,7 @@ async def create_app():
),
Route(
"/v1/chat/{chat_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_chat.chat_delete_endpoint,
chat_service=services["chat_service"],
@ -1359,7 +1398,7 @@ async def create_app():
# Search endpoint
Route(
"/v1/search",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_search.search_endpoint,
search_service=services["search_service"],
@ -1371,7 +1410,7 @@ async def create_app():
# Documents endpoints
Route(
"/v1/documents/ingest",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_documents.ingest_endpoint,
document_service=services["document_service"],
@ -1384,7 +1423,7 @@ async def create_app():
),
Route(
"/v1/tasks/{task_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_documents.task_status_endpoint,
task_service=services["task_service"],
@ -1395,7 +1434,7 @@ async def create_app():
),
Route(
"/v1/documents",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_documents.delete_document_endpoint,
document_service=services["document_service"],
@ -1407,14 +1446,14 @@ async def create_app():
# Settings endpoints
Route(
"/v1/settings",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(v1_settings.get_settings_endpoint)
),
methods=["GET"],
),
Route(
"/v1/settings",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_settings.update_settings_endpoint,
session_manager=services["session_manager"],
@ -1425,7 +1464,7 @@ async def create_app():
# Knowledge filters endpoints
Route(
"/v1/knowledge-filters",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_knowledge_filters.create_endpoint,
knowledge_filter_service=services["knowledge_filter_service"],
@ -1436,7 +1475,7 @@ async def create_app():
),
Route(
"/v1/knowledge-filters/search",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_knowledge_filters.search_endpoint,
knowledge_filter_service=services["knowledge_filter_service"],
@ -1447,7 +1486,7 @@ async def create_app():
),
Route(
"/v1/knowledge-filters/{filter_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_knowledge_filters.get_endpoint,
knowledge_filter_service=services["knowledge_filter_service"],
@ -1458,7 +1497,7 @@ async def create_app():
),
Route(
"/v1/knowledge-filters/{filter_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_knowledge_filters.update_endpoint,
knowledge_filter_service=services["knowledge_filter_service"],
@ -1469,7 +1508,7 @@ async def create_app():
),
Route(
"/v1/knowledge-filters/{filter_id}",
require_api_key(services["api_key_service"])(
require_api_key(services["api_key_service"], services["session_manager"])(
partial(
v1_knowledge_filters.delete_endpoint,
knowledge_filter_service=services["knowledge_filter_service"],

View file

@ -158,6 +158,7 @@ class TaskProcessor:
connector_type: str = "local",
embedding_model: str = None,
is_sample_data: bool = False,
allowed_groups: list = None,
):
"""
Standard processing pipeline for non-Langflow processors:
@ -166,6 +167,8 @@ class TaskProcessor:
Args:
embedding_model: Embedding model to use (defaults to the current
embedding model from settings)
allowed_groups: List of groups that can access this document (RBAC).
Empty list or None means no group restrictions.
"""
import datetime
from config.settings import INDEX_NAME, clients, get_embedding_model
@ -259,6 +262,11 @@ class TaskProcessor:
if owner_email is not None:
chunk_doc["owner_email"] = owner_email
# RBAC: Set allowed groups for access control
# If allowed_groups is provided and non-empty, store it for DLS filtering
if allowed_groups:
chunk_doc["allowed_groups"] = allowed_groups
# Mark as sample data if specified
if is_sample_data:
chunk_doc["is_sample_data"] = "true"
@ -309,6 +317,7 @@ class DocumentFileProcessor(TaskProcessor):
owner_name: str = None,
owner_email: str = None,
is_sample_data: bool = False,
allowed_groups: list = None,
):
super().__init__(document_service)
self.owner_user_id = owner_user_id
@ -316,6 +325,7 @@ class DocumentFileProcessor(TaskProcessor):
self.owner_name = owner_name
self.owner_email = owner_email
self.is_sample_data = is_sample_data
self.allowed_groups = allowed_groups
async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask
@ -351,6 +361,7 @@ class DocumentFileProcessor(TaskProcessor):
file_size=file_size,
connector_type="local",
is_sample_data=self.is_sample_data,
allowed_groups=self.allowed_groups,
)
file_task.status = TaskStatus.COMPLETED
@ -382,6 +393,7 @@ class ConnectorFileProcessor(TaskProcessor):
owner_name: str = None,
owner_email: str = None,
document_service=None,
allowed_groups: list = None,
):
super().__init__(document_service=document_service)
self.connector_service = connector_service
@ -391,6 +403,7 @@ class ConnectorFileProcessor(TaskProcessor):
self.jwt_token = jwt_token
self.owner_name = owner_name
self.owner_email = owner_email
self.allowed_groups = allowed_groups
async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask
@ -445,6 +458,7 @@ class ConnectorFileProcessor(TaskProcessor):
owner_email=self.owner_email,
file_size=len(document.content),
connector_type=connection.connector_type,
allowed_groups=self.allowed_groups,
)
# Add connector-specific metadata
@ -478,6 +492,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
jwt_token: str = None,
owner_name: str = None,
owner_email: str = None,
allowed_groups: list = None,
):
super().__init__()
self.langflow_connector_service = langflow_connector_service
@ -487,6 +502,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
self.jwt_token = jwt_token
self.owner_name = owner_name
self.owner_email = owner_email
self.allowed_groups = allowed_groups
async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask
@ -580,6 +596,7 @@ class S3FileProcessor(TaskProcessor):
jwt_token: str = None,
owner_name: str = None,
owner_email: str = None,
allowed_groups: list = None,
):
import boto3
@ -590,6 +607,7 @@ class S3FileProcessor(TaskProcessor):
self.jwt_token = jwt_token
self.owner_name = owner_name
self.owner_email = owner_email
self.allowed_groups = allowed_groups
async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask
@ -638,6 +656,7 @@ class S3FileProcessor(TaskProcessor):
owner_email=self.owner_email,
file_size=file_size,
connector_type="s3",
allowed_groups=self.allowed_groups,
)
result["path"] = f"s3://{self.bucket}/{item}"
@ -669,6 +688,7 @@ class LangflowFileProcessor(TaskProcessor):
settings: dict = None,
delete_after_ingest: bool = True,
replace_duplicates: bool = False,
allowed_groups: list = None,
):
super().__init__()
self.langflow_file_service = langflow_file_service
@ -682,6 +702,7 @@ class LangflowFileProcessor(TaskProcessor):
self.settings = settings
self.delete_after_ingest = delete_after_ingest
self.replace_duplicates = replace_duplicates
self.allowed_groups = allowed_groups
async def process_item(
self, upload_task: UploadTask, item: str, file_task: FileTask
@ -765,6 +786,10 @@ class LangflowFileProcessor(TaskProcessor):
metadata_tweaks.append({"key": "owner_email", "value": self.owner_email})
# Mark as local upload for connector_type
metadata_tweaks.append({"key": "connector_type", "value": "local"})
# RBAC: Add allowed_groups for access control
if self.allowed_groups:
# Store as comma-separated string for Langflow metadata
metadata_tweaks.append({"key": "allowed_groups", "value": ",".join(self.allowed_groups)})
if metadata_tweaks:
# Initialize the OpenSearch component tweaks if not already present

View file

@ -52,15 +52,19 @@ class APIKeyService:
user_email: str,
name: str,
jwt_token: str = None,
roles: List[str] = None,
groups: List[str] = None,
) -> Dict[str, Any]:
"""
Create a new API key for a user.
Create a new API key for a user with optional RBAC restrictions.
Args:
user_id: The user's ID
user_email: The user's email
name: A friendly name for the key
jwt_token: JWT token for OpenSearch authentication
roles: Optional list of roles to restrict this key to
groups: Optional list of groups this key can access
Returns:
Dict with success status, key info, and the full key (only shown once)
@ -74,6 +78,14 @@ class APIKeyService:
now = datetime.utcnow().isoformat()
# Default roles if not specified
if roles is None:
roles = ["openrag_user"]
# Default groups to empty if not specified
if groups is None:
groups = []
# Create the document to store
key_doc = {
"key_id": key_id,
@ -85,6 +97,9 @@ class APIKeyService:
"created_at": now,
"last_used_at": None,
"revoked": False,
# RBAC fields
"roles": roles,
"groups": groups,
}
# Get OpenSearch client
@ -105,6 +120,8 @@ class APIKeyService:
user_id=user_id,
key_id=key_id,
key_prefix=key_prefix,
roles=roles,
groups=groups,
)
return {
"success": True,
@ -112,6 +129,8 @@ class APIKeyService:
"key_prefix": key_prefix,
"name": name,
"created_at": now,
"roles": roles,
"groups": groups,
"api_key": full_key, # Only returned once!
}
else:
@ -123,13 +142,13 @@ class APIKeyService:
async def validate_key(self, api_key: str) -> Optional[Dict[str, Any]]:
"""
Validate an API key and return user info if valid.
Validate an API key and return user info with RBAC claims if valid.
Args:
api_key: The full API key to validate
Returns:
Dict with user info if valid, None if invalid
Dict with user info including roles and groups if valid, None if invalid
"""
try:
# Check key format
@ -181,11 +200,15 @@ class APIKeyService:
except Exception:
pass # Don't fail validation if update fails
# Return user info with RBAC claims
return {
"key_id": key_doc["key_id"],
"user_id": key_doc["user_id"],
"user_email": key_doc["user_email"],
"name": key_doc["name"],
# RBAC fields - provide defaults for backward compatibility
"roles": key_doc.get("roles", ["openrag_user"]),
"groups": key_doc.get("groups", []),
}
except Exception as e:
@ -225,6 +248,8 @@ class APIKeyService:
"created_at",
"last_used_at",
"revoked",
"roles",
"groups",
],
"size": 100,
}

View file

@ -16,6 +16,8 @@ class ChatService:
previous_response_id: str = None,
stream: bool = False,
filter_id: str = None,
groups: list = None,
roles: list = None,
):
"""Handle chat requests using the patched OpenAI client"""
if not prompt:
@ -23,7 +25,7 @@ class ChatService:
# Set authentication context for this request so tools can access it
if user_id and jwt_token:
set_auth_context(user_id, jwt_token)
set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
if stream:
return async_chat_stream(
@ -54,6 +56,8 @@ class ChatService:
previous_response_id: str = None,
stream: bool = False,
filter_id: str = None,
groups: list = None,
roles: list = None,
):
"""Handle Langflow chat requests"""
if not prompt:
@ -346,9 +350,9 @@ class ChatService:
previous_response_id=previous_response_id,
)
else: # chat
# Set auth context for chat tools and provide user_id
# Set auth context for chat tools and provide user_id with RBAC
if user_id and jwt_token:
set_auth_context(user_id, jwt_token)
set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
response_text, response_id = await async_chat(
clients.patched_llm_client,
document_prompt,

View file

@ -0,0 +1,219 @@
"""
Group Service for managing user groups for RBAC.
"""
import secrets
from datetime import datetime
from typing import Any, Dict, List, Optional
from config.settings import GROUPS_INDEX_NAME
from utils.logging_config import get_logger
logger = get_logger(__name__)
class GroupService:
"""Service for managing user groups for RBAC."""
def __init__(self, session_manager=None):
self.session_manager = session_manager
async def _ensure_index_exists(self, opensearch_client) -> None:
"""Ensure the groups index exists."""
from config.settings import GROUPS_INDEX_BODY
try:
exists = await opensearch_client.indices.exists(index=GROUPS_INDEX_NAME)
if not exists:
await opensearch_client.indices.create(
index=GROUPS_INDEX_NAME,
body=GROUPS_INDEX_BODY,
)
logger.info(f"Created groups index: {GROUPS_INDEX_NAME}")
except Exception as e:
# Index might already exist from concurrent creation
if "resource_already_exists_exception" not in str(e):
logger.error(f"Failed to create groups index: {e}")
raise
async def create_group(
self,
name: str,
description: str = "",
) -> Dict[str, Any]:
"""
Create a new user group.
Args:
name: The group name (must be unique)
description: Optional description of the group
Returns:
Dict with success status and group info
"""
try:
# Get OpenSearch client
from config.settings import clients
opensearch_client = clients.opensearch
# Ensure index exists
await self._ensure_index_exists(opensearch_client)
# Check if group with this name already exists
search_body = {
"query": {"term": {"name": name}},
"size": 1,
}
result = await opensearch_client.search(
index=GROUPS_INDEX_NAME,
body=search_body,
)
if result.get("hits", {}).get("hits", []):
return {"success": False, "error": f"Group '{name}' already exists"}
# Create a unique group_id
group_id = secrets.token_urlsafe(16)
now = datetime.utcnow().isoformat()
# Create the document to store
group_doc = {
"group_id": group_id,
"name": name,
"description": description,
"created_at": now,
}
# Index the group document
result = await opensearch_client.index(
index=GROUPS_INDEX_NAME,
id=group_id,
body=group_doc,
refresh="wait_for",
)
if result.get("result") in ("created", "updated"):
logger.info(f"Created group: {name} (id: {group_id})")
return {
"success": True,
"group_id": group_id,
"name": name,
"description": description,
"created_at": now,
}
else:
return {"success": False, "error": "Failed to create group"}
except Exception as e:
logger.error(f"Failed to create group: {e}")
return {"success": False, "error": str(e)}
async def list_groups(self) -> Dict[str, Any]:
"""
List all user groups.
Returns:
Dict with list of groups
"""
try:
# Get OpenSearch client
from config.settings import clients
opensearch_client = clients.opensearch
# Ensure index exists
await self._ensure_index_exists(opensearch_client)
# Search for all groups
search_body = {
"query": {"match_all": {}},
"sort": [{"name": {"order": "asc"}}],
"_source": ["group_id", "name", "description", "created_at"],
"size": 1000,
}
result = await opensearch_client.search(
index=GROUPS_INDEX_NAME,
body=search_body,
)
groups = []
for hit in result.get("hits", {}).get("hits", []):
groups.append(hit["_source"])
return {"success": True, "groups": groups}
except Exception as e:
logger.error(f"Failed to list groups: {e}")
return {"success": False, "error": str(e), "groups": []}
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
"""
Get a group by ID.
Args:
group_id: The group ID
Returns:
Group info if found, None otherwise
"""
try:
# Get OpenSearch client
from config.settings import clients
opensearch_client = clients.opensearch
doc = await opensearch_client.get(
index=GROUPS_INDEX_NAME,
id=group_id,
)
return doc["_source"]
except Exception:
return None
async def delete_group(self, group_id: str) -> Dict[str, Any]:
"""
Delete a user group.
Args:
group_id: The group ID to delete
Returns:
Dict with success status
"""
try:
# Get OpenSearch client
from config.settings import clients
opensearch_client = clients.opensearch
# Verify the group exists
try:
doc = await opensearch_client.get(
index=GROUPS_INDEX_NAME,
id=group_id,
)
group_name = doc["_source"].get("name", "unknown")
except Exception:
return {"success": False, "error": "Group not found"}
# Delete the group
result = await opensearch_client.delete(
index=GROUPS_INDEX_NAME,
id=group_id,
refresh="wait_for",
)
if result.get("result") == "deleted":
logger.info(f"Deleted group: {group_name} (id: {group_id})")
return {"success": True}
else:
return {"success": False, "error": "Failed to delete group"}
except Exception as e:
logger.error(f"Failed to delete group: {e}")
return {"success": False, "error": str(e)}

View file

@ -2,7 +2,7 @@ import copy
from typing import Any, Dict
from agentd.tool_decorator import tool
from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model, WATSONX_EMBEDDING_DIMENSIONS
from auth_context import get_auth_context
from auth_context import get_auth_context, get_current_user_groups
from utils.logging_config import get_logger
logger = get_logger(__name__)
@ -259,8 +259,24 @@ class SearchService:
# Build query body
if is_wildcard_match_all:
# Match all documents; still allow filters to narrow scope
if filter_clauses:
query_block = {"bool": {"filter": filter_clauses}}
# Also add RBAC group filter for wildcard queries
wildcard_filters = list(filter_clauses) # Copy existing filters
user_groups = get_current_user_groups()
if user_groups:
groups_access_filter = {
"bool": {
"should": [
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
{"terms": {"allowed_groups": user_groups}},
],
"minimum_should_match": 1
}
}
wildcard_filters.append(groups_access_filter)
if wildcard_filters:
query_block = {"bool": {"filter": wildcard_filters}}
else:
query_block = {"match_all": {}}
else:
@ -292,6 +308,29 @@ class SearchService:
# Add exists filter to existing filters
all_filters = [*filter_clauses, exists_any_embedding]
# RBAC: Add group-based access control filter (fallback if DLS isn't configured)
# Documents are accessible if:
# 1. No allowed_groups field exists (backward compatibility, open access)
# 2. User's groups match any of the document's allowed_groups
user_groups = get_current_user_groups()
if user_groups:
groups_access_filter = {
"bool": {
"should": [
# Document has no allowed_groups restriction
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
# User's groups match document's allowed_groups
{"terms": {"allowed_groups": user_groups}},
],
"minimum_should_match": 1
}
}
all_filters.append(groups_access_filter)
logger.debug(
"Added RBAC group filter",
user_groups=user_groups,
)
logger.debug(
"Building hybrid query with filters",
user_filters_count=len(filter_clauses),

View file

@ -2,8 +2,8 @@ import json
import jwt
import httpx
from datetime import datetime, timedelta
from typing import Dict, Optional, Any
from dataclasses import dataclass, asdict
from typing import Dict, Optional, Any, List
from dataclasses import dataclass, field
from cryptography.hazmat.primitives import serialization
import os
from utils.logging_config import get_logger
@ -24,12 +24,19 @@ class User:
provider: str = "google"
created_at: datetime = None
last_login: datetime = None
# RBAC fields
roles: List[str] = field(default_factory=lambda: ["openrag_user"])
groups: List[str] = field(default_factory=list)
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.now()
if self.last_login is None:
self.last_login = datetime.now()
if self.roles is None:
self.roles = ["openrag_user"]
if self.groups is None:
self.groups = []
class AnonymousUser(User):
"""Anonymous user"""
@ -136,11 +143,31 @@ class SessionManager:
# Create JWT token using the shared method
return self.create_jwt_token(user)
def create_jwt_token(self, user: User) -> str:
"""Create JWT token for an existing user"""
def create_jwt_token(
self,
user: User,
roles: Optional[List[str]] = None,
groups: Optional[List[str]] = None,
expiration_days: int = 7,
) -> str:
"""Create JWT token for an existing user.
Args:
user: The User object to create a token for
roles: Optional roles override (for restricted API key tokens)
groups: Optional groups override (for restricted API key tokens)
expiration_days: Token expiration in days (default 7)
Returns:
Encoded JWT token string
"""
# Use OpenSearch-compatible issuer for OIDC validation
oidc_issuer = "http://openrag-backend:8000"
# Use provided roles/groups or fall back to user's defaults
effective_roles = roles if roles is not None else user.roles
effective_groups = groups if groups is not None else user.groups
# Create JWT token with OIDC-compliant claims
now = datetime.utcnow()
token_payload = {
@ -148,7 +175,7 @@ class SessionManager:
"iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC
"sub": user.user_id, # Subject (user ID)
"aud": ["opensearch", "openrag"], # Audience
"exp": now + timedelta(days=7), # Expiration
"exp": now + timedelta(days=expiration_days), # Expiration
"iat": now, # Issued at
"auth_time": int(now.timestamp()), # Authentication time
# Custom claims
@ -157,7 +184,9 @@ class SessionManager:
"name": user.name,
"preferred_username": user.email,
"email_verified": True,
"roles": ["openrag_user"], # Backend role for OpenSearch
# RBAC claims
"roles": effective_roles, # Backend roles for OpenSearch/tools
"groups": effective_groups, # Group-based access control
}
token = jwt.encode(token_payload, self.private_key, algorithm="RS256")