Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6faa77d5c7 | ||
|
|
ce51628db2 |
21 changed files with 1139 additions and 98 deletions
|
|
@ -6,6 +6,8 @@ import {
|
|||
|
||||
export interface CreateApiKeyRequest {
|
||||
name: string;
|
||||
roles?: string[];
|
||||
groups?: string[];
|
||||
}
|
||||
|
||||
export interface CreateApiKeyResponse {
|
||||
|
|
@ -14,6 +16,8 @@ export interface CreateApiKeyResponse {
|
|||
name: string;
|
||||
key_prefix: string;
|
||||
created_at: string;
|
||||
roles?: string[];
|
||||
groups?: string[];
|
||||
}
|
||||
|
||||
export const useCreateApiKeyMutation = (
|
||||
|
|
|
|||
61
frontend/app/api/mutations/useCreateGroupMutation.ts
Normal file
61
frontend/app/api/mutations/useCreateGroupMutation.ts
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import {
|
||||
type UseMutationOptions,
|
||||
useMutation,
|
||||
useQueryClient,
|
||||
} from "@tanstack/react-query";
|
||||
|
||||
export interface CreateGroupRequest {
|
||||
name: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
export interface CreateGroupResponse {
|
||||
success: boolean;
|
||||
group_id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
created_at: string;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export const useCreateGroupMutation = (
|
||||
options?: Omit<
|
||||
UseMutationOptions<CreateGroupResponse, Error, CreateGroupRequest>,
|
||||
"mutationFn"
|
||||
>,
|
||||
) => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
async function createGroup(
|
||||
variables: CreateGroupRequest,
|
||||
): Promise<CreateGroupResponse> {
|
||||
const response = await fetch("/api/groups", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(variables),
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(data.error || "Failed to create group");
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
return useMutation({
|
||||
mutationFn: createGroup,
|
||||
onSuccess: (...args) => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["groups"],
|
||||
});
|
||||
options?.onSuccess?.(...args);
|
||||
},
|
||||
onError: options?.onError,
|
||||
onSettled: options?.onSettled,
|
||||
});
|
||||
};
|
||||
|
||||
52
frontend/app/api/mutations/useDeleteGroupMutation.ts
Normal file
52
frontend/app/api/mutations/useDeleteGroupMutation.ts
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import {
|
||||
type UseMutationOptions,
|
||||
useMutation,
|
||||
useQueryClient,
|
||||
} from "@tanstack/react-query";
|
||||
|
||||
export interface DeleteGroupRequest {
|
||||
group_id: string;
|
||||
}
|
||||
|
||||
export interface DeleteGroupResponse {
|
||||
success: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export const useDeleteGroupMutation = (
|
||||
options?: Omit<
|
||||
UseMutationOptions<DeleteGroupResponse, Error, DeleteGroupRequest>,
|
||||
"mutationFn"
|
||||
>,
|
||||
) => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
async function deleteGroup(
|
||||
variables: DeleteGroupRequest,
|
||||
): Promise<DeleteGroupResponse> {
|
||||
const response = await fetch(`/api/groups/${variables.group_id}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(data.error || "Failed to delete group");
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
return useMutation({
|
||||
mutationFn: deleteGroup,
|
||||
onSuccess: (...args) => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["groups"],
|
||||
});
|
||||
options?.onSuccess?.(...args);
|
||||
},
|
||||
onError: options?.onError,
|
||||
onSettled: options?.onSettled,
|
||||
});
|
||||
};
|
||||
|
||||
|
|
@ -6,6 +6,8 @@ export interface ApiKey {
|
|||
key_prefix: string;
|
||||
created_at: string;
|
||||
last_used_at: string | null;
|
||||
roles?: string[];
|
||||
groups?: string[];
|
||||
}
|
||||
|
||||
export interface GetApiKeysResponse {
|
||||
|
|
|
|||
32
frontend/app/api/queries/useGetGroupsQuery.ts
Normal file
32
frontend/app/api/queries/useGetGroupsQuery.ts
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
import { type UseQueryOptions, useQuery } from "@tanstack/react-query";
|
||||
|
||||
export interface Group {
|
||||
group_id: string;
|
||||
name: string;
|
||||
description: string;
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
export interface GetGroupsResponse {
|
||||
success: boolean;
|
||||
groups: Group[];
|
||||
}
|
||||
|
||||
export const useGetGroupsQuery = (
|
||||
options?: Omit<UseQueryOptions<GetGroupsResponse>, "queryKey" | "queryFn">,
|
||||
) => {
|
||||
async function getGroups(): Promise<GetGroupsResponse> {
|
||||
const response = await fetch("/api/groups");
|
||||
if (response.ok) {
|
||||
return await response.json();
|
||||
}
|
||||
throw new Error("Failed to fetch groups");
|
||||
}
|
||||
|
||||
return useQuery({
|
||||
queryKey: ["groups"],
|
||||
queryFn: getGroups,
|
||||
...options,
|
||||
});
|
||||
};
|
||||
|
||||
|
|
@ -12,8 +12,11 @@ import {
|
|||
useGetOpenAIModelsQuery,
|
||||
} from "@/app/api/queries/useGetModelsQuery";
|
||||
import { useGetApiKeysQuery } from "@/app/api/queries/useGetApiKeysQuery";
|
||||
import { useGetGroupsQuery } from "@/app/api/queries/useGetGroupsQuery";
|
||||
import { useCreateApiKeyMutation } from "@/app/api/mutations/useCreateApiKeyMutation";
|
||||
import { useRevokeApiKeyMutation } from "@/app/api/mutations/useRevokeApiKeyMutation";
|
||||
import { ManageGroupsModal } from "@/components/ManageGroupsModal";
|
||||
import { MultiSelect } from "@/components/ui/multi-select";
|
||||
import { useGetSettingsQuery } from "@/app/api/queries/useGetSettingsQuery";
|
||||
import { ConfirmationDialog } from "@/components/confirmation-dialog";
|
||||
import {
|
||||
|
|
@ -136,8 +139,10 @@ function KnowledgeSourcesPage() {
|
|||
// API Keys state
|
||||
const [createKeyDialogOpen, setCreateKeyDialogOpen] = useState(false);
|
||||
const [newKeyName, setNewKeyName] = useState("");
|
||||
const [newKeyGroups, setNewKeyGroups] = useState<string[]>([]);
|
||||
const [newlyCreatedKey, setNewlyCreatedKey] = useState<string | null>(null);
|
||||
const [showKeyDialogOpen, setShowKeyDialogOpen] = useState(false);
|
||||
const [manageGroupsOpen, setManageGroupsOpen] = useState(false);
|
||||
|
||||
// Fetch settings using React Query
|
||||
const { data: settings = {} } = useGetSettingsQuery({
|
||||
|
|
@ -149,6 +154,11 @@ function KnowledgeSourcesPage() {
|
|||
enabled: isAuthenticated || isNoAuthMode,
|
||||
});
|
||||
|
||||
// Fetch user groups
|
||||
const { data: groupsData } = useGetGroupsQuery({
|
||||
enabled: isAuthenticated || isNoAuthMode,
|
||||
});
|
||||
|
||||
// API key mutations
|
||||
const createApiKeyMutation = useCreateApiKeyMutation({
|
||||
onSuccess: (data) => {
|
||||
|
|
@ -156,6 +166,7 @@ function KnowledgeSourcesPage() {
|
|||
setCreateKeyDialogOpen(false);
|
||||
setShowKeyDialogOpen(true);
|
||||
setNewKeyName("");
|
||||
setNewKeyGroups([]);
|
||||
toast.success("API key created");
|
||||
},
|
||||
onError: (error) => {
|
||||
|
|
@ -438,7 +449,11 @@ function KnowledgeSourcesPage() {
|
|||
toast.error("Please enter a name for the API key");
|
||||
return;
|
||||
}
|
||||
createApiKeyMutation.mutate({ name: newKeyName.trim() });
|
||||
// Use selected groups directly (already an array)
|
||||
createApiKeyMutation.mutate({
|
||||
name: newKeyName.trim(),
|
||||
groups: newKeyGroups.length > 0 ? newKeyGroups : undefined,
|
||||
});
|
||||
};
|
||||
|
||||
const handleRevokeApiKey = (keyId: string) => {
|
||||
|
|
@ -1426,6 +1441,9 @@ function KnowledgeSourcesPage() {
|
|||
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
||||
Key
|
||||
</th>
|
||||
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
||||
Groups
|
||||
</th>
|
||||
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
||||
Created
|
||||
</th>
|
||||
|
|
@ -1448,6 +1466,24 @@ function KnowledgeSourcesPage() {
|
|||
{key.key_prefix}...
|
||||
</code>
|
||||
</td>
|
||||
<td className="px-4 py-3 text-sm text-muted-foreground">
|
||||
{key.groups && key.groups.length > 0 ? (
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{key.groups.map((group: string) => (
|
||||
<span
|
||||
key={group}
|
||||
className="inline-flex items-center px-2 py-0.5 rounded-full text-xs bg-primary/10 text-primary"
|
||||
>
|
||||
{group}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<span className="text-muted-foreground/50">
|
||||
All groups
|
||||
</span>
|
||||
)}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-sm text-muted-foreground">
|
||||
{formatDate(key.created_at)}
|
||||
</td>
|
||||
|
|
@ -1507,58 +1543,88 @@ function KnowledgeSourcesPage() {
|
|||
</Card>
|
||||
)}
|
||||
|
||||
{/* Create API Key Dialog */}
|
||||
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Create API Key</DialogTitle>
|
||||
<DialogDescription>
|
||||
Give your API key a name to help you identify it later.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="py-4">
|
||||
<LabelWrapper label="Name" id="api-key-name">
|
||||
<Input
|
||||
id="api-key-name"
|
||||
placeholder="e.g., Production App, Development"
|
||||
value={newKeyName}
|
||||
onChange={(e) => setNewKeyName(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
handleCreateApiKey();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</LabelWrapper>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button
|
||||
variant="ghost"
|
||||
onClick={() => {
|
||||
setCreateKeyDialogOpen(false);
|
||||
setNewKeyName("");
|
||||
{/* Create API Key Dialog */}
|
||||
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Create API Key</DialogTitle>
|
||||
<DialogDescription>
|
||||
Create an API key with optional group restrictions for access
|
||||
control.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className="py-4 space-y-4">
|
||||
<LabelWrapper label="Name" id="api-key-name">
|
||||
<Input
|
||||
id="api-key-name"
|
||||
placeholder="e.g., Production App, Development"
|
||||
value={newKeyName}
|
||||
onChange={(e) => setNewKeyName(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
handleCreateApiKey();
|
||||
}
|
||||
}}
|
||||
size="sm"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleCreateApiKey}
|
||||
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
|
||||
size="sm"
|
||||
>
|
||||
{createApiKeyMutation.isPending ? (
|
||||
<>
|
||||
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
|
||||
Creating...
|
||||
</>
|
||||
) : (
|
||||
"Create Key"
|
||||
)}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
/>
|
||||
</LabelWrapper>
|
||||
<div className="space-y-2">
|
||||
<LabelWrapper
|
||||
label="Groups (optional)"
|
||||
id="api-key-groups"
|
||||
helperText="Restrict this key to specific groups"
|
||||
>
|
||||
<MultiSelect
|
||||
options={(groupsData?.groups || []).map((g) => ({
|
||||
value: g.name,
|
||||
label: g.name,
|
||||
}))}
|
||||
value={newKeyGroups}
|
||||
onValueChange={setNewKeyGroups}
|
||||
placeholder="Select groups..."
|
||||
showAllOption={false}
|
||||
searchPlaceholder="Search groups..."
|
||||
/>
|
||||
</LabelWrapper>
|
||||
<Button
|
||||
variant="link"
|
||||
size="sm"
|
||||
className="h-auto p-0 text-xs"
|
||||
onClick={() => setManageGroupsOpen(true)}
|
||||
>
|
||||
<Plus className="h-3 w-3 mr-1" />
|
||||
Manage Groups
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<Button
|
||||
variant="ghost"
|
||||
onClick={() => {
|
||||
setCreateKeyDialogOpen(false);
|
||||
setNewKeyName("");
|
||||
setNewKeyGroups([]);
|
||||
}}
|
||||
size="sm"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleCreateApiKey}
|
||||
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
|
||||
size="sm"
|
||||
>
|
||||
{createApiKeyMutation.isPending ? (
|
||||
<>
|
||||
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
|
||||
Creating...
|
||||
</>
|
||||
) : (
|
||||
"Create Key"
|
||||
)}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
{/* Show Created API Key Dialog */}
|
||||
<Dialog
|
||||
|
|
@ -1593,6 +1659,12 @@ function KnowledgeSourcesPage() {
|
|||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
{/* Manage Groups Modal */}
|
||||
<ManageGroupsModal
|
||||
open={manageGroupsOpen}
|
||||
onOpenChange={setManageGroupsOpen}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
175
frontend/components/ManageGroupsModal.tsx
Normal file
175
frontend/components/ManageGroupsModal.tsx
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { useGetGroupsQuery } from "@/app/api/queries/useGetGroupsQuery";
|
||||
import { useCreateGroupMutation } from "@/app/api/mutations/useCreateGroupMutation";
|
||||
import { useDeleteGroupMutation } from "@/app/api/mutations/useDeleteGroupMutation";
|
||||
import { Plus, Trash2, Loader2, Users } from "lucide-react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface ManageGroupsModalProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
}
|
||||
|
||||
export function ManageGroupsModal({
|
||||
open,
|
||||
onOpenChange,
|
||||
}: ManageGroupsModalProps) {
|
||||
const [newGroupName, setNewGroupName] = useState("");
|
||||
const [newGroupDescription, setNewGroupDescription] = useState("");
|
||||
|
||||
const { data: groupsData, isLoading: groupsLoading } = useGetGroupsQuery({
|
||||
enabled: open,
|
||||
});
|
||||
|
||||
const createGroupMutation = useCreateGroupMutation({
|
||||
onSuccess: () => {
|
||||
setNewGroupName("");
|
||||
setNewGroupDescription("");
|
||||
toast.success("Group created successfully");
|
||||
},
|
||||
onError: (error) => {
|
||||
toast.error("Failed to create group", { description: error.message });
|
||||
},
|
||||
});
|
||||
|
||||
const deleteGroupMutation = useDeleteGroupMutation({
|
||||
onSuccess: () => {
|
||||
toast.success("Group deleted successfully");
|
||||
},
|
||||
onError: (error) => {
|
||||
toast.error("Failed to delete group", { description: error.message });
|
||||
},
|
||||
});
|
||||
|
||||
const handleCreateGroup = () => {
|
||||
if (!newGroupName.trim()) {
|
||||
toast.error("Please enter a group name");
|
||||
return;
|
||||
}
|
||||
createGroupMutation.mutate({
|
||||
name: newGroupName.trim(),
|
||||
description: newGroupDescription.trim(),
|
||||
});
|
||||
};
|
||||
|
||||
const handleDeleteGroup = (groupId: string, groupName: string) => {
|
||||
if (confirm(`Are you sure you want to delete the group "${groupName}"?`)) {
|
||||
deleteGroupMutation.mutate({ group_id: groupId });
|
||||
}
|
||||
};
|
||||
|
||||
const groups = groupsData?.groups || [];
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-md">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Manage User Groups</DialogTitle>
|
||||
<DialogDescription>
|
||||
Create and manage user groups for access control. Groups can be
|
||||
assigned to API keys to restrict document access.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="space-y-4 py-4">
|
||||
{/* Add new group section */}
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Create New Group</label>
|
||||
<div className="flex gap-2">
|
||||
<Input
|
||||
placeholder="Group name"
|
||||
value={newGroupName}
|
||||
onChange={(e) => setNewGroupName(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter") {
|
||||
handleCreateGroup();
|
||||
}
|
||||
}}
|
||||
className="flex-1"
|
||||
/>
|
||||
<Button
|
||||
onClick={handleCreateGroup}
|
||||
disabled={
|
||||
createGroupMutation.isPending || !newGroupName.trim()
|
||||
}
|
||||
size="sm"
|
||||
>
|
||||
{createGroupMutation.isPending ? (
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
) : (
|
||||
<Plus className="h-4 w-4" />
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
<Input
|
||||
placeholder="Description (optional)"
|
||||
value={newGroupDescription}
|
||||
onChange={(e) => setNewGroupDescription(e.target.value)}
|
||||
className="text-sm"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Existing groups list */}
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Existing Groups</label>
|
||||
{groupsLoading ? (
|
||||
<div className="flex items-center justify-center py-4">
|
||||
<Loader2 className="h-5 w-5 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
) : groups.length > 0 ? (
|
||||
<div className="border rounded-lg divide-y max-h-48 overflow-y-auto">
|
||||
{groups.map((group) => (
|
||||
<div
|
||||
key={group.group_id}
|
||||
className="flex items-center justify-between px-3 py-2 hover:bg-muted/50"
|
||||
>
|
||||
<div className="flex-1 min-w-0">
|
||||
<p className="text-sm font-medium truncate">
|
||||
{group.name}
|
||||
</p>
|
||||
{group.description && (
|
||||
<p className="text-xs text-muted-foreground truncate">
|
||||
{group.description}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="h-8 w-8 p-0 text-destructive hover:text-destructive hover:bg-destructive/10"
|
||||
onClick={() =>
|
||||
handleDeleteGroup(group.group_id, group.name)
|
||||
}
|
||||
disabled={deleteGroupMutation.isPending}
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="text-center py-6 border rounded-lg">
|
||||
<Users className="h-8 w-8 mx-auto text-muted-foreground/50 mb-2" />
|
||||
<p className="text-sm text-muted-foreground">
|
||||
No groups yet. Create one above.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -35,6 +35,10 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
|||
set_search_limit(limit)
|
||||
set_score_threshold(score_threshold)
|
||||
|
||||
# Get RBAC groups and roles from user for access control
|
||||
user_groups = getattr(user, "groups", [])
|
||||
user_roles = getattr(user, "roles", [])
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(
|
||||
await chat_service.chat(
|
||||
|
|
@ -44,6 +48,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
|||
previous_response_id=previous_response_id,
|
||||
stream=True,
|
||||
filter_id=filter_id,
|
||||
groups=user_groups,
|
||||
roles=user_roles,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -61,6 +67,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
|||
previous_response_id=previous_response_id,
|
||||
stream=False,
|
||||
filter_id=filter_id,
|
||||
groups=user_groups,
|
||||
roles=user_roles,
|
||||
)
|
||||
return JSONResponse(result)
|
||||
|
||||
|
|
@ -81,6 +89,10 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
|||
|
||||
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
|
||||
|
||||
# Get RBAC groups and roles from user for access control
|
||||
user_groups = getattr(user, "groups", [])
|
||||
user_roles = getattr(user, "roles", [])
|
||||
|
||||
if not prompt:
|
||||
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
||||
|
||||
|
|
@ -105,6 +117,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
|||
previous_response_id=previous_response_id,
|
||||
stream=True,
|
||||
filter_id=filter_id,
|
||||
groups=user_groups,
|
||||
roles=user_roles,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
|
|
@ -122,6 +136,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
|||
previous_response_id=previous_response_id,
|
||||
stream=False,
|
||||
filter_id=filter_id,
|
||||
groups=user_groups,
|
||||
roles=user_roles,
|
||||
)
|
||||
return JSONResponse(result)
|
||||
|
||||
|
|
|
|||
114
src/api/groups.py
Normal file
114
src/api/groups.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
User Groups management endpoints.
|
||||
|
||||
These endpoints allow managing user groups for RBAC.
|
||||
"""
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def list_groups_endpoint(request: Request, group_service):
|
||||
"""
|
||||
List all user groups.
|
||||
|
||||
GET /groups
|
||||
|
||||
Response:
|
||||
{
|
||||
"success": true,
|
||||
"groups": [
|
||||
{
|
||||
"group_id": "...",
|
||||
"name": "finance",
|
||||
"description": "Finance team",
|
||||
"created_at": "2024-01-01T00:00:00"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
result = await group_service.list_groups()
|
||||
return JSONResponse(result)
|
||||
|
||||
|
||||
async def create_group_endpoint(request: Request, group_service):
|
||||
"""
|
||||
Create a new user group.
|
||||
|
||||
POST /groups
|
||||
Body: {"name": "finance", "description": "Finance team"}
|
||||
|
||||
Response:
|
||||
{
|
||||
"success": true,
|
||||
"group_id": "...",
|
||||
"name": "finance",
|
||||
"description": "Finance team",
|
||||
"created_at": "2024-01-01T00:00:00"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
name = data.get("name", "").strip()
|
||||
description = data.get("description", "").strip()
|
||||
|
||||
if not name:
|
||||
return JSONResponse(
|
||||
{"success": False, "error": "Name is required"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if len(name) > 100:
|
||||
return JSONResponse(
|
||||
{"success": False, "error": "Name must be 100 characters or less"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
result = await group_service.create_group(
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
return JSONResponse(result)
|
||||
elif "already exists" in result.get("error", ""):
|
||||
return JSONResponse(result, status_code=409)
|
||||
else:
|
||||
return JSONResponse(result, status_code=500)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create group: {e}")
|
||||
return JSONResponse(
|
||||
{"success": False, "error": str(e)},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
|
||||
async def delete_group_endpoint(request: Request, group_service):
|
||||
"""
|
||||
Delete a user group.
|
||||
|
||||
DELETE /groups/{group_id}
|
||||
|
||||
Response:
|
||||
{"success": true}
|
||||
"""
|
||||
group_id = request.path_params.get("group_id")
|
||||
|
||||
if not group_id:
|
||||
return JSONResponse(
|
||||
{"success": False, "error": "Group ID is required"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
result = await group_service.delete_group(group_id=group_id)
|
||||
|
||||
if result.get("success"):
|
||||
return JSONResponse(result)
|
||||
elif result.get("error") == "Group not found":
|
||||
return JSONResponse(result, status_code=404)
|
||||
else:
|
||||
return JSONResponse(result, status_code=500)
|
||||
|
||||
|
|
@ -42,10 +42,14 @@ async def list_keys_endpoint(request: Request, api_key_service):
|
|||
|
||||
async def create_key_endpoint(request: Request, api_key_service):
|
||||
"""
|
||||
Create a new API key for the authenticated user.
|
||||
Create a new API key for the authenticated user with optional RBAC restrictions.
|
||||
|
||||
POST /keys
|
||||
Body: {"name": "My API Key"}
|
||||
Body: {
|
||||
"name": "My API Key",
|
||||
"roles": ["openrag_user"], // Optional: restrict key to specific roles
|
||||
"groups": ["finance", "hr"] // Optional: restrict key to specific groups
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
|
|
@ -53,6 +57,8 @@ async def create_key_endpoint(request: Request, api_key_service):
|
|||
"key_id": "...",
|
||||
"key_prefix": "orag_abc12345",
|
||||
"name": "My API Key",
|
||||
"roles": ["openrag_user"],
|
||||
"groups": ["finance", "hr"],
|
||||
"created_at": "2024-01-01T00:00:00",
|
||||
"api_key": "orag_abc12345..." // Full key, only shown once!
|
||||
}
|
||||
|
|
@ -78,11 +84,29 @@ async def create_key_endpoint(request: Request, api_key_service):
|
|||
status_code=400,
|
||||
)
|
||||
|
||||
# Extract optional RBAC fields
|
||||
roles = data.get("roles")
|
||||
groups = data.get("groups")
|
||||
|
||||
# Validate roles and groups are lists if provided
|
||||
if roles is not None and not isinstance(roles, list):
|
||||
return JSONResponse(
|
||||
{"success": False, "error": "roles must be a list"},
|
||||
status_code=400,
|
||||
)
|
||||
if groups is not None and not isinstance(groups, list):
|
||||
return JSONResponse(
|
||||
{"success": False, "error": "groups must be a list"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
result = await api_key_service.create_key(
|
||||
user_id=user_id,
|
||||
user_email=user_email,
|
||||
name=name,
|
||||
jwt_token=jwt_token,
|
||||
roles=roles,
|
||||
groups=groups,
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
|
|
|
|||
|
|
@ -104,14 +104,18 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
|||
|
||||
user = request.state.user
|
||||
user_id = user.user_id
|
||||
jwt_token = session_manager.get_effective_jwt_token(user_id, None)
|
||||
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
|
||||
|
||||
# Get RBAC groups and roles from user for access control
|
||||
user_groups = getattr(user, "groups", [])
|
||||
user_roles = getattr(user, "roles", [])
|
||||
|
||||
# Set context variables for search tool
|
||||
if filters:
|
||||
set_search_filters(filters)
|
||||
set_search_limit(limit)
|
||||
set_score_threshold(score_threshold)
|
||||
set_auth_context(user_id, jwt_token)
|
||||
set_auth_context(user_id, jwt_token, groups=user_groups, roles=user_roles)
|
||||
|
||||
if stream:
|
||||
raw_stream = await chat_service.langflow_chat(
|
||||
|
|
@ -121,6 +125,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
|||
previous_response_id=chat_id,
|
||||
stream=True,
|
||||
filter_id=filter_id,
|
||||
groups=user_groups,
|
||||
roles=user_roles,
|
||||
)
|
||||
chat_id_container = {}
|
||||
return StreamingResponse(
|
||||
|
|
@ -136,6 +142,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
|||
previous_response_id=chat_id,
|
||||
stream=False,
|
||||
filter_id=filter_id,
|
||||
groups=user_groups,
|
||||
roles=user_roles,
|
||||
)
|
||||
# Transform response_id to chat_id for v1 API format
|
||||
return JSONResponse({
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
"""
|
||||
API Key middleware for authenticating public API requests.
|
||||
|
||||
This middleware validates API keys and generates ephemeral JWTs with the
|
||||
key's specific roles and groups for downstream security enforcement.
|
||||
"""
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
|
@ -32,14 +35,18 @@ def _extract_api_key(request: Request) -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
def require_api_key(api_key_service):
|
||||
def require_api_key(api_key_service, session_manager=None):
|
||||
"""
|
||||
Decorator to require API key authentication for public API endpoints.
|
||||
|
||||
Generates an ephemeral JWT with the API key's specific roles and groups
|
||||
to enforce RBAC in downstream services (OpenSearch, tools, etc.).
|
||||
|
||||
Usage:
|
||||
@require_api_key(api_key_service)
|
||||
@require_api_key(api_key_service, session_manager)
|
||||
async def my_endpoint(request):
|
||||
user = request.state.user
|
||||
jwt_token = request.state.jwt_token # Ephemeral restricted JWT
|
||||
...
|
||||
"""
|
||||
|
||||
|
|
@ -57,7 +64,7 @@ def require_api_key(api_key_service):
|
|||
status_code=401,
|
||||
)
|
||||
|
||||
# Validate the key
|
||||
# Validate the key and get RBAC claims
|
||||
user_info = await api_key_service.validate_key(api_key)
|
||||
|
||||
if not user_info:
|
||||
|
|
@ -69,19 +76,46 @@ def require_api_key(api_key_service):
|
|||
status_code=401,
|
||||
)
|
||||
|
||||
# Create a User object from the API key info
|
||||
# Extract RBAC fields from API key
|
||||
key_roles = user_info.get("roles", ["openrag_user"])
|
||||
key_groups = user_info.get("groups", [])
|
||||
|
||||
# Create a User object with the API key's roles and groups
|
||||
user = User(
|
||||
user_id=user_info["user_id"],
|
||||
email=user_info["user_email"],
|
||||
name=user_info.get("name", "API User"),
|
||||
picture=None,
|
||||
provider="api_key",
|
||||
roles=key_roles,
|
||||
groups=key_groups,
|
||||
)
|
||||
|
||||
# Set request state
|
||||
request.state.user = user
|
||||
request.state.api_key_id = user_info["key_id"]
|
||||
request.state.jwt_token = None # No JWT for API key auth
|
||||
|
||||
# Generate ephemeral JWT with the API key's restricted roles/groups
|
||||
if session_manager:
|
||||
# Create a short-lived JWT with the key's specific permissions
|
||||
ephemeral_jwt = session_manager.create_jwt_token(
|
||||
user=user,
|
||||
roles=key_roles,
|
||||
groups=key_groups,
|
||||
expiration_days=1, # Short-lived for API requests
|
||||
)
|
||||
request.state.jwt_token = ephemeral_jwt
|
||||
logger.debug(
|
||||
"Generated ephemeral JWT for API key",
|
||||
key_id=user_info["key_id"],
|
||||
roles=key_roles,
|
||||
groups=key_groups,
|
||||
)
|
||||
else:
|
||||
request.state.jwt_token = None
|
||||
logger.warning(
|
||||
"No session_manager provided - JWT not generated for API key"
|
||||
)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
|
|
@ -90,10 +124,13 @@ def require_api_key(api_key_service):
|
|||
return decorator
|
||||
|
||||
|
||||
def optional_api_key(api_key_service):
|
||||
def optional_api_key(api_key_service, session_manager=None):
|
||||
"""
|
||||
Decorator to optionally authenticate with API key.
|
||||
Sets request.state.user to None if no valid API key is provided.
|
||||
|
||||
When a valid API key is provided, generates an ephemeral JWT with
|
||||
the key's specific roles and groups.
|
||||
"""
|
||||
|
||||
def decorator(handler):
|
||||
|
|
@ -102,21 +139,38 @@ def optional_api_key(api_key_service):
|
|||
api_key = _extract_api_key(request)
|
||||
|
||||
if api_key:
|
||||
# Validate the key
|
||||
# Validate the key and get RBAC claims
|
||||
user_info = await api_key_service.validate_key(api_key)
|
||||
|
||||
if user_info:
|
||||
# Create a User object from the API key info
|
||||
# Extract RBAC fields from API key
|
||||
key_roles = user_info.get("roles", ["openrag_user"])
|
||||
key_groups = user_info.get("groups", [])
|
||||
|
||||
# Create a User object with the API key's roles and groups
|
||||
user = User(
|
||||
user_id=user_info["user_id"],
|
||||
email=user_info["user_email"],
|
||||
name=user_info.get("name", "API User"),
|
||||
picture=None,
|
||||
provider="api_key",
|
||||
roles=key_roles,
|
||||
groups=key_groups,
|
||||
)
|
||||
request.state.user = user
|
||||
request.state.api_key_id = user_info["key_id"]
|
||||
request.state.jwt_token = None
|
||||
|
||||
# Generate ephemeral JWT with the API key's restricted roles/groups
|
||||
if session_manager:
|
||||
ephemeral_jwt = session_manager.create_jwt_token(
|
||||
user=user,
|
||||
roles=key_roles,
|
||||
groups=key_groups,
|
||||
expiration_days=1,
|
||||
)
|
||||
request.state.jwt_token = ephemeral_jwt
|
||||
else:
|
||||
request.state.jwt_token = None
|
||||
else:
|
||||
request.state.user = None
|
||||
request.state.api_key_id = None
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ Uses contextvars to safely pass user auth info through async calls.
|
|||
"""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
# Context variables for current request authentication
|
||||
_current_user_id: ContextVar[Optional[str]] = ContextVar(
|
||||
|
|
@ -13,6 +13,12 @@ _current_user_id: ContextVar[Optional[str]] = ContextVar(
|
|||
_current_jwt_token: ContextVar[Optional[str]] = ContextVar(
|
||||
"current_jwt_token", default=None
|
||||
)
|
||||
_current_user_groups: ContextVar[Optional[List[str]]] = ContextVar(
|
||||
"current_user_groups", default=None
|
||||
)
|
||||
_current_user_roles: ContextVar[Optional[List[str]]] = ContextVar(
|
||||
"current_user_roles", default=None
|
||||
)
|
||||
_current_search_filters: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
|
||||
"current_search_filters", default=None
|
||||
)
|
||||
|
|
@ -24,10 +30,24 @@ _current_score_threshold: ContextVar[Optional[float]] = ContextVar(
|
|||
)
|
||||
|
||||
|
||||
def set_auth_context(user_id: str, jwt_token: str):
|
||||
"""Set authentication context for the current async context"""
|
||||
def set_auth_context(
|
||||
user_id: str,
|
||||
jwt_token: str,
|
||||
groups: Optional[List[str]] = None,
|
||||
roles: Optional[List[str]] = None,
|
||||
):
|
||||
"""Set authentication context for the current async context
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
jwt_token: The JWT token for authentication
|
||||
groups: Optional list of groups the user belongs to (for RBAC)
|
||||
roles: Optional list of roles the user has (for RBAC)
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_jwt_token.set(jwt_token)
|
||||
_current_user_groups.set(groups or [])
|
||||
_current_user_roles.set(roles or [])
|
||||
|
||||
|
||||
def get_current_user_id() -> Optional[str]:
|
||||
|
|
@ -40,6 +60,16 @@ def get_current_jwt_token() -> Optional[str]:
|
|||
return _current_jwt_token.get()
|
||||
|
||||
|
||||
def get_current_user_groups() -> List[str]:
|
||||
"""Get current user's groups from context (for RBAC)"""
|
||||
return _current_user_groups.get() or []
|
||||
|
||||
|
||||
def get_current_user_roles() -> List[str]:
|
||||
"""Get current user's roles from context (for RBAC)"""
|
||||
return _current_user_roles.get() or []
|
||||
|
||||
|
||||
def get_auth_context() -> tuple[Optional[str], Optional[str]]:
|
||||
"""Get current authentication context (user_id, jwt_token)"""
|
||||
return _current_user_id.get(), _current_jwt_token.get()
|
||||
|
|
|
|||
|
|
@ -165,6 +165,23 @@ API_KEYS_INDEX_BODY = {
|
|||
},
|
||||
}
|
||||
|
||||
# User Groups index for RBAC management
|
||||
GROUPS_INDEX_NAME = "openrag_groups"
|
||||
GROUPS_INDEX_BODY = {
|
||||
"settings": {
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0,
|
||||
},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"group_id": {"type": "keyword"},
|
||||
"name": {"type": "keyword"},
|
||||
"description": {"type": "text"},
|
||||
"created_at": {"type": "date"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Convenience base URL for Langflow REST API
|
||||
LANGFLOW_BASE_URL = f"{LANGFLOW_URL}/api/v1"
|
||||
|
||||
|
|
|
|||
69
src/main.py
69
src/main.py
|
|
@ -58,6 +58,10 @@ from auth_middleware import optional_auth, require_auth
|
|||
from api_key_middleware import require_api_key
|
||||
from services.api_key_service import APIKeyService
|
||||
from api import keys as api_keys
|
||||
|
||||
# User Groups management
|
||||
from services.group_service import GroupService
|
||||
from api import groups as api_groups
|
||||
from api.v1 import chat as v1_chat, search as v1_search, documents as v1_documents, settings as v1_settings, knowledge_filters as v1_knowledge_filters
|
||||
|
||||
# Configuration and setup
|
||||
|
|
@ -665,6 +669,9 @@ async def initialize_services():
|
|||
# API Key service for public API authentication
|
||||
api_key_service = APIKeyService(session_manager)
|
||||
|
||||
# Group service for RBAC management
|
||||
group_service = GroupService(session_manager)
|
||||
|
||||
return {
|
||||
"document_service": document_service,
|
||||
"search_service": search_service,
|
||||
|
|
@ -679,6 +686,7 @@ async def initialize_services():
|
|||
"monitor_service": monitor_service,
|
||||
"session_manager": session_manager,
|
||||
"api_key_service": api_key_service,
|
||||
"group_service": group_service,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -1310,11 +1318,42 @@ async def create_app():
|
|||
),
|
||||
methods=["DELETE"],
|
||||
),
|
||||
# ===== User Groups Management Endpoints (JWT auth for UI) =====
|
||||
Route(
|
||||
"/groups",
|
||||
require_auth(services["session_manager"])(
|
||||
partial(
|
||||
api_groups.list_groups_endpoint,
|
||||
group_service=services["group_service"],
|
||||
)
|
||||
),
|
||||
methods=["GET"],
|
||||
),
|
||||
Route(
|
||||
"/groups",
|
||||
require_auth(services["session_manager"])(
|
||||
partial(
|
||||
api_groups.create_group_endpoint,
|
||||
group_service=services["group_service"],
|
||||
)
|
||||
),
|
||||
methods=["POST"],
|
||||
),
|
||||
Route(
|
||||
"/groups/{group_id}",
|
||||
require_auth(services["session_manager"])(
|
||||
partial(
|
||||
api_groups.delete_group_endpoint,
|
||||
group_service=services["group_service"],
|
||||
)
|
||||
),
|
||||
methods=["DELETE"],
|
||||
),
|
||||
# ===== Public API v1 Endpoints (API Key auth) =====
|
||||
# Chat endpoints
|
||||
Route(
|
||||
"/v1/chat",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_chat.chat_create_endpoint,
|
||||
chat_service=services["chat_service"],
|
||||
|
|
@ -1325,7 +1364,7 @@ async def create_app():
|
|||
),
|
||||
Route(
|
||||
"/v1/chat",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_chat.chat_list_endpoint,
|
||||
chat_service=services["chat_service"],
|
||||
|
|
@ -1336,7 +1375,7 @@ async def create_app():
|
|||
),
|
||||
Route(
|
||||
"/v1/chat/{chat_id}",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_chat.chat_get_endpoint,
|
||||
chat_service=services["chat_service"],
|
||||
|
|
@ -1347,7 +1386,7 @@ async def create_app():
|
|||
),
|
||||
Route(
|
||||
"/v1/chat/{chat_id}",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_chat.chat_delete_endpoint,
|
||||
chat_service=services["chat_service"],
|
||||
|
|
@ -1359,7 +1398,7 @@ async def create_app():
|
|||
# Search endpoint
|
||||
Route(
|
||||
"/v1/search",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_search.search_endpoint,
|
||||
search_service=services["search_service"],
|
||||
|
|
@ -1371,7 +1410,7 @@ async def create_app():
|
|||
# Documents endpoints
|
||||
Route(
|
||||
"/v1/documents/ingest",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_documents.ingest_endpoint,
|
||||
document_service=services["document_service"],
|
||||
|
|
@ -1384,7 +1423,7 @@ async def create_app():
|
|||
),
|
||||
Route(
|
||||
"/v1/tasks/{task_id}",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_documents.task_status_endpoint,
|
||||
task_service=services["task_service"],
|
||||
|
|
@ -1395,7 +1434,7 @@ async def create_app():
|
|||
),
|
||||
Route(
|
||||
"/v1/documents",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_documents.delete_document_endpoint,
|
||||
document_service=services["document_service"],
|
||||
|
|
@ -1407,14 +1446,14 @@ async def create_app():
|
|||
# Settings endpoints
|
||||
Route(
|
||||
"/v1/settings",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(v1_settings.get_settings_endpoint)
|
||||
),
|
||||
methods=["GET"],
|
||||
),
|
||||
Route(
|
||||
"/v1/settings",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_settings.update_settings_endpoint,
|
||||
session_manager=services["session_manager"],
|
||||
|
|
@ -1425,7 +1464,7 @@ async def create_app():
|
|||
# Knowledge filters endpoints
|
||||
Route(
|
||||
"/v1/knowledge-filters",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_knowledge_filters.create_endpoint,
|
||||
knowledge_filter_service=services["knowledge_filter_service"],
|
||||
|
|
@ -1436,7 +1475,7 @@ async def create_app():
|
|||
),
|
||||
Route(
|
||||
"/v1/knowledge-filters/search",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_knowledge_filters.search_endpoint,
|
||||
knowledge_filter_service=services["knowledge_filter_service"],
|
||||
|
|
@ -1447,7 +1486,7 @@ async def create_app():
|
|||
),
|
||||
Route(
|
||||
"/v1/knowledge-filters/{filter_id}",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_knowledge_filters.get_endpoint,
|
||||
knowledge_filter_service=services["knowledge_filter_service"],
|
||||
|
|
@ -1458,7 +1497,7 @@ async def create_app():
|
|||
),
|
||||
Route(
|
||||
"/v1/knowledge-filters/{filter_id}",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_knowledge_filters.update_endpoint,
|
||||
knowledge_filter_service=services["knowledge_filter_service"],
|
||||
|
|
@ -1469,7 +1508,7 @@ async def create_app():
|
|||
),
|
||||
Route(
|
||||
"/v1/knowledge-filters/{filter_id}",
|
||||
require_api_key(services["api_key_service"])(
|
||||
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||
partial(
|
||||
v1_knowledge_filters.delete_endpoint,
|
||||
knowledge_filter_service=services["knowledge_filter_service"],
|
||||
|
|
|
|||
|
|
@ -158,6 +158,7 @@ class TaskProcessor:
|
|||
connector_type: str = "local",
|
||||
embedding_model: str = None,
|
||||
is_sample_data: bool = False,
|
||||
allowed_groups: list = None,
|
||||
):
|
||||
"""
|
||||
Standard processing pipeline for non-Langflow processors:
|
||||
|
|
@ -166,6 +167,8 @@ class TaskProcessor:
|
|||
Args:
|
||||
embedding_model: Embedding model to use (defaults to the current
|
||||
embedding model from settings)
|
||||
allowed_groups: List of groups that can access this document (RBAC).
|
||||
Empty list or None means no group restrictions.
|
||||
"""
|
||||
import datetime
|
||||
from config.settings import INDEX_NAME, clients, get_embedding_model
|
||||
|
|
@ -259,6 +262,11 @@ class TaskProcessor:
|
|||
if owner_email is not None:
|
||||
chunk_doc["owner_email"] = owner_email
|
||||
|
||||
# RBAC: Set allowed groups for access control
|
||||
# If allowed_groups is provided and non-empty, store it for DLS filtering
|
||||
if allowed_groups:
|
||||
chunk_doc["allowed_groups"] = allowed_groups
|
||||
|
||||
# Mark as sample data if specified
|
||||
if is_sample_data:
|
||||
chunk_doc["is_sample_data"] = "true"
|
||||
|
|
@ -309,6 +317,7 @@ class DocumentFileProcessor(TaskProcessor):
|
|||
owner_name: str = None,
|
||||
owner_email: str = None,
|
||||
is_sample_data: bool = False,
|
||||
allowed_groups: list = None,
|
||||
):
|
||||
super().__init__(document_service)
|
||||
self.owner_user_id = owner_user_id
|
||||
|
|
@ -316,6 +325,7 @@ class DocumentFileProcessor(TaskProcessor):
|
|||
self.owner_name = owner_name
|
||||
self.owner_email = owner_email
|
||||
self.is_sample_data = is_sample_data
|
||||
self.allowed_groups = allowed_groups
|
||||
|
||||
async def process_item(
|
||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||
|
|
@ -351,6 +361,7 @@ class DocumentFileProcessor(TaskProcessor):
|
|||
file_size=file_size,
|
||||
connector_type="local",
|
||||
is_sample_data=self.is_sample_data,
|
||||
allowed_groups=self.allowed_groups,
|
||||
)
|
||||
|
||||
file_task.status = TaskStatus.COMPLETED
|
||||
|
|
@ -382,6 +393,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
|||
owner_name: str = None,
|
||||
owner_email: str = None,
|
||||
document_service=None,
|
||||
allowed_groups: list = None,
|
||||
):
|
||||
super().__init__(document_service=document_service)
|
||||
self.connector_service = connector_service
|
||||
|
|
@ -391,6 +403,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
|||
self.jwt_token = jwt_token
|
||||
self.owner_name = owner_name
|
||||
self.owner_email = owner_email
|
||||
self.allowed_groups = allowed_groups
|
||||
|
||||
async def process_item(
|
||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||
|
|
@ -445,6 +458,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
|||
owner_email=self.owner_email,
|
||||
file_size=len(document.content),
|
||||
connector_type=connection.connector_type,
|
||||
allowed_groups=self.allowed_groups,
|
||||
)
|
||||
|
||||
# Add connector-specific metadata
|
||||
|
|
@ -478,6 +492,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
|
|||
jwt_token: str = None,
|
||||
owner_name: str = None,
|
||||
owner_email: str = None,
|
||||
allowed_groups: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.langflow_connector_service = langflow_connector_service
|
||||
|
|
@ -487,6 +502,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
|
|||
self.jwt_token = jwt_token
|
||||
self.owner_name = owner_name
|
||||
self.owner_email = owner_email
|
||||
self.allowed_groups = allowed_groups
|
||||
|
||||
async def process_item(
|
||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||
|
|
@ -580,6 +596,7 @@ class S3FileProcessor(TaskProcessor):
|
|||
jwt_token: str = None,
|
||||
owner_name: str = None,
|
||||
owner_email: str = None,
|
||||
allowed_groups: list = None,
|
||||
):
|
||||
import boto3
|
||||
|
||||
|
|
@ -590,6 +607,7 @@ class S3FileProcessor(TaskProcessor):
|
|||
self.jwt_token = jwt_token
|
||||
self.owner_name = owner_name
|
||||
self.owner_email = owner_email
|
||||
self.allowed_groups = allowed_groups
|
||||
|
||||
async def process_item(
|
||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||
|
|
@ -638,6 +656,7 @@ class S3FileProcessor(TaskProcessor):
|
|||
owner_email=self.owner_email,
|
||||
file_size=file_size,
|
||||
connector_type="s3",
|
||||
allowed_groups=self.allowed_groups,
|
||||
)
|
||||
|
||||
result["path"] = f"s3://{self.bucket}/{item}"
|
||||
|
|
@ -669,6 +688,7 @@ class LangflowFileProcessor(TaskProcessor):
|
|||
settings: dict = None,
|
||||
delete_after_ingest: bool = True,
|
||||
replace_duplicates: bool = False,
|
||||
allowed_groups: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.langflow_file_service = langflow_file_service
|
||||
|
|
@ -682,6 +702,7 @@ class LangflowFileProcessor(TaskProcessor):
|
|||
self.settings = settings
|
||||
self.delete_after_ingest = delete_after_ingest
|
||||
self.replace_duplicates = replace_duplicates
|
||||
self.allowed_groups = allowed_groups
|
||||
|
||||
async def process_item(
|
||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||
|
|
@ -765,6 +786,10 @@ class LangflowFileProcessor(TaskProcessor):
|
|||
metadata_tweaks.append({"key": "owner_email", "value": self.owner_email})
|
||||
# Mark as local upload for connector_type
|
||||
metadata_tweaks.append({"key": "connector_type", "value": "local"})
|
||||
# RBAC: Add allowed_groups for access control
|
||||
if self.allowed_groups:
|
||||
# Store as comma-separated string for Langflow metadata
|
||||
metadata_tweaks.append({"key": "allowed_groups", "value": ",".join(self.allowed_groups)})
|
||||
|
||||
if metadata_tweaks:
|
||||
# Initialize the OpenSearch component tweaks if not already present
|
||||
|
|
|
|||
|
|
@ -52,15 +52,19 @@ class APIKeyService:
|
|||
user_email: str,
|
||||
name: str,
|
||||
jwt_token: str = None,
|
||||
roles: List[str] = None,
|
||||
groups: List[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a new API key for a user.
|
||||
Create a new API key for a user with optional RBAC restrictions.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID
|
||||
user_email: The user's email
|
||||
name: A friendly name for the key
|
||||
jwt_token: JWT token for OpenSearch authentication
|
||||
roles: Optional list of roles to restrict this key to
|
||||
groups: Optional list of groups this key can access
|
||||
|
||||
Returns:
|
||||
Dict with success status, key info, and the full key (only shown once)
|
||||
|
|
@ -74,6 +78,14 @@ class APIKeyService:
|
|||
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
# Default roles if not specified
|
||||
if roles is None:
|
||||
roles = ["openrag_user"]
|
||||
|
||||
# Default groups to empty if not specified
|
||||
if groups is None:
|
||||
groups = []
|
||||
|
||||
# Create the document to store
|
||||
key_doc = {
|
||||
"key_id": key_id,
|
||||
|
|
@ -85,6 +97,9 @@ class APIKeyService:
|
|||
"created_at": now,
|
||||
"last_used_at": None,
|
||||
"revoked": False,
|
||||
# RBAC fields
|
||||
"roles": roles,
|
||||
"groups": groups,
|
||||
}
|
||||
|
||||
# Get OpenSearch client
|
||||
|
|
@ -105,6 +120,8 @@ class APIKeyService:
|
|||
user_id=user_id,
|
||||
key_id=key_id,
|
||||
key_prefix=key_prefix,
|
||||
roles=roles,
|
||||
groups=groups,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
|
|
@ -112,6 +129,8 @@ class APIKeyService:
|
|||
"key_prefix": key_prefix,
|
||||
"name": name,
|
||||
"created_at": now,
|
||||
"roles": roles,
|
||||
"groups": groups,
|
||||
"api_key": full_key, # Only returned once!
|
||||
}
|
||||
else:
|
||||
|
|
@ -123,13 +142,13 @@ class APIKeyService:
|
|||
|
||||
async def validate_key(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Validate an API key and return user info if valid.
|
||||
Validate an API key and return user info with RBAC claims if valid.
|
||||
|
||||
Args:
|
||||
api_key: The full API key to validate
|
||||
|
||||
Returns:
|
||||
Dict with user info if valid, None if invalid
|
||||
Dict with user info including roles and groups if valid, None if invalid
|
||||
"""
|
||||
try:
|
||||
# Check key format
|
||||
|
|
@ -181,11 +200,15 @@ class APIKeyService:
|
|||
except Exception:
|
||||
pass # Don't fail validation if update fails
|
||||
|
||||
# Return user info with RBAC claims
|
||||
return {
|
||||
"key_id": key_doc["key_id"],
|
||||
"user_id": key_doc["user_id"],
|
||||
"user_email": key_doc["user_email"],
|
||||
"name": key_doc["name"],
|
||||
# RBAC fields - provide defaults for backward compatibility
|
||||
"roles": key_doc.get("roles", ["openrag_user"]),
|
||||
"groups": key_doc.get("groups", []),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -225,6 +248,8 @@ class APIKeyService:
|
|||
"created_at",
|
||||
"last_used_at",
|
||||
"revoked",
|
||||
"roles",
|
||||
"groups",
|
||||
],
|
||||
"size": 100,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ class ChatService:
|
|||
previous_response_id: str = None,
|
||||
stream: bool = False,
|
||||
filter_id: str = None,
|
||||
groups: list = None,
|
||||
roles: list = None,
|
||||
):
|
||||
"""Handle chat requests using the patched OpenAI client"""
|
||||
if not prompt:
|
||||
|
|
@ -23,7 +25,7 @@ class ChatService:
|
|||
|
||||
# Set authentication context for this request so tools can access it
|
||||
if user_id and jwt_token:
|
||||
set_auth_context(user_id, jwt_token)
|
||||
set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
|
||||
|
||||
if stream:
|
||||
return async_chat_stream(
|
||||
|
|
@ -54,6 +56,8 @@ class ChatService:
|
|||
previous_response_id: str = None,
|
||||
stream: bool = False,
|
||||
filter_id: str = None,
|
||||
groups: list = None,
|
||||
roles: list = None,
|
||||
):
|
||||
"""Handle Langflow chat requests"""
|
||||
if not prompt:
|
||||
|
|
@ -346,9 +350,9 @@ class ChatService:
|
|||
previous_response_id=previous_response_id,
|
||||
)
|
||||
else: # chat
|
||||
# Set auth context for chat tools and provide user_id
|
||||
# Set auth context for chat tools and provide user_id with RBAC
|
||||
if user_id and jwt_token:
|
||||
set_auth_context(user_id, jwt_token)
|
||||
set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
|
||||
response_text, response_id = await async_chat(
|
||||
clients.patched_llm_client,
|
||||
document_prompt,
|
||||
|
|
|
|||
219
src/services/group_service.py
Normal file
219
src/services/group_service.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
"""
|
||||
Group Service for managing user groups for RBAC.
|
||||
"""
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from config.settings import GROUPS_INDEX_NAME
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GroupService:
|
||||
"""Service for managing user groups for RBAC."""
|
||||
|
||||
def __init__(self, session_manager=None):
|
||||
self.session_manager = session_manager
|
||||
|
||||
async def _ensure_index_exists(self, opensearch_client) -> None:
|
||||
"""Ensure the groups index exists."""
|
||||
from config.settings import GROUPS_INDEX_BODY
|
||||
|
||||
try:
|
||||
exists = await opensearch_client.indices.exists(index=GROUPS_INDEX_NAME)
|
||||
if not exists:
|
||||
await opensearch_client.indices.create(
|
||||
index=GROUPS_INDEX_NAME,
|
||||
body=GROUPS_INDEX_BODY,
|
||||
)
|
||||
logger.info(f"Created groups index: {GROUPS_INDEX_NAME}")
|
||||
except Exception as e:
|
||||
# Index might already exist from concurrent creation
|
||||
if "resource_already_exists_exception" not in str(e):
|
||||
logger.error(f"Failed to create groups index: {e}")
|
||||
raise
|
||||
|
||||
async def create_group(
|
||||
self,
|
||||
name: str,
|
||||
description: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a new user group.
|
||||
|
||||
Args:
|
||||
name: The group name (must be unique)
|
||||
description: Optional description of the group
|
||||
|
||||
Returns:
|
||||
Dict with success status and group info
|
||||
"""
|
||||
try:
|
||||
# Get OpenSearch client
|
||||
from config.settings import clients
|
||||
|
||||
opensearch_client = clients.opensearch
|
||||
|
||||
# Ensure index exists
|
||||
await self._ensure_index_exists(opensearch_client)
|
||||
|
||||
# Check if group with this name already exists
|
||||
search_body = {
|
||||
"query": {"term": {"name": name}},
|
||||
"size": 1,
|
||||
}
|
||||
|
||||
result = await opensearch_client.search(
|
||||
index=GROUPS_INDEX_NAME,
|
||||
body=search_body,
|
||||
)
|
||||
|
||||
if result.get("hits", {}).get("hits", []):
|
||||
return {"success": False, "error": f"Group '{name}' already exists"}
|
||||
|
||||
# Create a unique group_id
|
||||
group_id = secrets.token_urlsafe(16)
|
||||
now = datetime.utcnow().isoformat()
|
||||
|
||||
# Create the document to store
|
||||
group_doc = {
|
||||
"group_id": group_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"created_at": now,
|
||||
}
|
||||
|
||||
# Index the group document
|
||||
result = await opensearch_client.index(
|
||||
index=GROUPS_INDEX_NAME,
|
||||
id=group_id,
|
||||
body=group_doc,
|
||||
refresh="wait_for",
|
||||
)
|
||||
|
||||
if result.get("result") in ("created", "updated"):
|
||||
logger.info(f"Created group: {name} (id: {group_id})")
|
||||
return {
|
||||
"success": True,
|
||||
"group_id": group_id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"created_at": now,
|
||||
}
|
||||
else:
|
||||
return {"success": False, "error": "Failed to create group"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create group: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def list_groups(self) -> Dict[str, Any]:
|
||||
"""
|
||||
List all user groups.
|
||||
|
||||
Returns:
|
||||
Dict with list of groups
|
||||
"""
|
||||
try:
|
||||
# Get OpenSearch client
|
||||
from config.settings import clients
|
||||
|
||||
opensearch_client = clients.opensearch
|
||||
|
||||
# Ensure index exists
|
||||
await self._ensure_index_exists(opensearch_client)
|
||||
|
||||
# Search for all groups
|
||||
search_body = {
|
||||
"query": {"match_all": {}},
|
||||
"sort": [{"name": {"order": "asc"}}],
|
||||
"_source": ["group_id", "name", "description", "created_at"],
|
||||
"size": 1000,
|
||||
}
|
||||
|
||||
result = await opensearch_client.search(
|
||||
index=GROUPS_INDEX_NAME,
|
||||
body=search_body,
|
||||
)
|
||||
|
||||
groups = []
|
||||
for hit in result.get("hits", {}).get("hits", []):
|
||||
groups.append(hit["_source"])
|
||||
|
||||
return {"success": True, "groups": groups}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list groups: {e}")
|
||||
return {"success": False, "error": str(e), "groups": []}
|
||||
|
||||
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get a group by ID.
|
||||
|
||||
Args:
|
||||
group_id: The group ID
|
||||
|
||||
Returns:
|
||||
Group info if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Get OpenSearch client
|
||||
from config.settings import clients
|
||||
|
||||
opensearch_client = clients.opensearch
|
||||
|
||||
doc = await opensearch_client.get(
|
||||
index=GROUPS_INDEX_NAME,
|
||||
id=group_id,
|
||||
)
|
||||
|
||||
return doc["_source"]
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def delete_group(self, group_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Delete a user group.
|
||||
|
||||
Args:
|
||||
group_id: The group ID to delete
|
||||
|
||||
Returns:
|
||||
Dict with success status
|
||||
"""
|
||||
try:
|
||||
# Get OpenSearch client
|
||||
from config.settings import clients
|
||||
|
||||
opensearch_client = clients.opensearch
|
||||
|
||||
# Verify the group exists
|
||||
try:
|
||||
doc = await opensearch_client.get(
|
||||
index=GROUPS_INDEX_NAME,
|
||||
id=group_id,
|
||||
)
|
||||
group_name = doc["_source"].get("name", "unknown")
|
||||
except Exception:
|
||||
return {"success": False, "error": "Group not found"}
|
||||
|
||||
# Delete the group
|
||||
result = await opensearch_client.delete(
|
||||
index=GROUPS_INDEX_NAME,
|
||||
id=group_id,
|
||||
refresh="wait_for",
|
||||
)
|
||||
|
||||
if result.get("result") == "deleted":
|
||||
logger.info(f"Deleted group: {group_name} (id: {group_id})")
|
||||
return {"success": True}
|
||||
else:
|
||||
return {"success": False, "error": "Failed to delete group"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete group: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
|
@ -2,7 +2,7 @@ import copy
|
|||
from typing import Any, Dict
|
||||
from agentd.tool_decorator import tool
|
||||
from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model, WATSONX_EMBEDDING_DIMENSIONS
|
||||
from auth_context import get_auth_context
|
||||
from auth_context import get_auth_context, get_current_user_groups
|
||||
from utils.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
|
@ -259,8 +259,24 @@ class SearchService:
|
|||
# Build query body
|
||||
if is_wildcard_match_all:
|
||||
# Match all documents; still allow filters to narrow scope
|
||||
if filter_clauses:
|
||||
query_block = {"bool": {"filter": filter_clauses}}
|
||||
# Also add RBAC group filter for wildcard queries
|
||||
wildcard_filters = list(filter_clauses) # Copy existing filters
|
||||
|
||||
user_groups = get_current_user_groups()
|
||||
if user_groups:
|
||||
groups_access_filter = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
|
||||
{"terms": {"allowed_groups": user_groups}},
|
||||
],
|
||||
"minimum_should_match": 1
|
||||
}
|
||||
}
|
||||
wildcard_filters.append(groups_access_filter)
|
||||
|
||||
if wildcard_filters:
|
||||
query_block = {"bool": {"filter": wildcard_filters}}
|
||||
else:
|
||||
query_block = {"match_all": {}}
|
||||
else:
|
||||
|
|
@ -292,6 +308,29 @@ class SearchService:
|
|||
# Add exists filter to existing filters
|
||||
all_filters = [*filter_clauses, exists_any_embedding]
|
||||
|
||||
# RBAC: Add group-based access control filter (fallback if DLS isn't configured)
|
||||
# Documents are accessible if:
|
||||
# 1. No allowed_groups field exists (backward compatibility, open access)
|
||||
# 2. User's groups match any of the document's allowed_groups
|
||||
user_groups = get_current_user_groups()
|
||||
if user_groups:
|
||||
groups_access_filter = {
|
||||
"bool": {
|
||||
"should": [
|
||||
# Document has no allowed_groups restriction
|
||||
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
|
||||
# User's groups match document's allowed_groups
|
||||
{"terms": {"allowed_groups": user_groups}},
|
||||
],
|
||||
"minimum_should_match": 1
|
||||
}
|
||||
}
|
||||
all_filters.append(groups_access_filter)
|
||||
logger.debug(
|
||||
"Added RBAC group filter",
|
||||
user_groups=user_groups,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Building hybrid query with filters",
|
||||
user_filters_count=len(filter_clauses),
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ import json
|
|||
import jwt
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Dict, Optional, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
import os
|
||||
from utils.logging_config import get_logger
|
||||
|
|
@ -24,12 +24,19 @@ class User:
|
|||
provider: str = "google"
|
||||
created_at: datetime = None
|
||||
last_login: datetime = None
|
||||
# RBAC fields
|
||||
roles: List[str] = field(default_factory=lambda: ["openrag_user"])
|
||||
groups: List[str] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.now()
|
||||
if self.last_login is None:
|
||||
self.last_login = datetime.now()
|
||||
if self.roles is None:
|
||||
self.roles = ["openrag_user"]
|
||||
if self.groups is None:
|
||||
self.groups = []
|
||||
|
||||
class AnonymousUser(User):
|
||||
"""Anonymous user"""
|
||||
|
|
@ -136,11 +143,31 @@ class SessionManager:
|
|||
# Create JWT token using the shared method
|
||||
return self.create_jwt_token(user)
|
||||
|
||||
def create_jwt_token(self, user: User) -> str:
|
||||
"""Create JWT token for an existing user"""
|
||||
def create_jwt_token(
|
||||
self,
|
||||
user: User,
|
||||
roles: Optional[List[str]] = None,
|
||||
groups: Optional[List[str]] = None,
|
||||
expiration_days: int = 7,
|
||||
) -> str:
|
||||
"""Create JWT token for an existing user.
|
||||
|
||||
Args:
|
||||
user: The User object to create a token for
|
||||
roles: Optional roles override (for restricted API key tokens)
|
||||
groups: Optional groups override (for restricted API key tokens)
|
||||
expiration_days: Token expiration in days (default 7)
|
||||
|
||||
Returns:
|
||||
Encoded JWT token string
|
||||
"""
|
||||
# Use OpenSearch-compatible issuer for OIDC validation
|
||||
oidc_issuer = "http://openrag-backend:8000"
|
||||
|
||||
# Use provided roles/groups or fall back to user's defaults
|
||||
effective_roles = roles if roles is not None else user.roles
|
||||
effective_groups = groups if groups is not None else user.groups
|
||||
|
||||
# Create JWT token with OIDC-compliant claims
|
||||
now = datetime.utcnow()
|
||||
token_payload = {
|
||||
|
|
@ -148,7 +175,7 @@ class SessionManager:
|
|||
"iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC
|
||||
"sub": user.user_id, # Subject (user ID)
|
||||
"aud": ["opensearch", "openrag"], # Audience
|
||||
"exp": now + timedelta(days=7), # Expiration
|
||||
"exp": now + timedelta(days=expiration_days), # Expiration
|
||||
"iat": now, # Issued at
|
||||
"auth_time": int(now.timestamp()), # Authentication time
|
||||
# Custom claims
|
||||
|
|
@ -157,7 +184,9 @@ class SessionManager:
|
|||
"name": user.name,
|
||||
"preferred_username": user.email,
|
||||
"email_verified": True,
|
||||
"roles": ["openrag_user"], # Backend role for OpenSearch
|
||||
# RBAC claims
|
||||
"roles": effective_roles, # Backend roles for OpenSearch/tools
|
||||
"groups": effective_groups, # Group-based access control
|
||||
}
|
||||
|
||||
token = jwt.encode(token_payload, self.private_key, algorithm="RS256")
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue