Compare commits
2 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6faa77d5c7 | ||
|
|
ce51628db2 |
21 changed files with 1139 additions and 98 deletions
|
|
@ -6,6 +6,8 @@ import {
|
||||||
|
|
||||||
export interface CreateApiKeyRequest {
|
export interface CreateApiKeyRequest {
|
||||||
name: string;
|
name: string;
|
||||||
|
roles?: string[];
|
||||||
|
groups?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface CreateApiKeyResponse {
|
export interface CreateApiKeyResponse {
|
||||||
|
|
@ -14,6 +16,8 @@ export interface CreateApiKeyResponse {
|
||||||
name: string;
|
name: string;
|
||||||
key_prefix: string;
|
key_prefix: string;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
|
roles?: string[];
|
||||||
|
groups?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useCreateApiKeyMutation = (
|
export const useCreateApiKeyMutation = (
|
||||||
|
|
|
||||||
61
frontend/app/api/mutations/useCreateGroupMutation.ts
Normal file
61
frontend/app/api/mutations/useCreateGroupMutation.ts
Normal file
|
|
@ -0,0 +1,61 @@
|
||||||
|
import {
|
||||||
|
type UseMutationOptions,
|
||||||
|
useMutation,
|
||||||
|
useQueryClient,
|
||||||
|
} from "@tanstack/react-query";
|
||||||
|
|
||||||
|
export interface CreateGroupRequest {
|
||||||
|
name: string;
|
||||||
|
description?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface CreateGroupResponse {
|
||||||
|
success: boolean;
|
||||||
|
group_id: string;
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
created_at: string;
|
||||||
|
error?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useCreateGroupMutation = (
|
||||||
|
options?: Omit<
|
||||||
|
UseMutationOptions<CreateGroupResponse, Error, CreateGroupRequest>,
|
||||||
|
"mutationFn"
|
||||||
|
>,
|
||||||
|
) => {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
|
||||||
|
async function createGroup(
|
||||||
|
variables: CreateGroupRequest,
|
||||||
|
): Promise<CreateGroupResponse> {
|
||||||
|
const response = await fetch("/api/groups", {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify(variables),
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(data.error || "Failed to create group");
|
||||||
|
}
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: createGroup,
|
||||||
|
onSuccess: (...args) => {
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: ["groups"],
|
||||||
|
});
|
||||||
|
options?.onSuccess?.(...args);
|
||||||
|
},
|
||||||
|
onError: options?.onError,
|
||||||
|
onSettled: options?.onSettled,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
52
frontend/app/api/mutations/useDeleteGroupMutation.ts
Normal file
52
frontend/app/api/mutations/useDeleteGroupMutation.ts
Normal file
|
|
@ -0,0 +1,52 @@
|
||||||
|
import {
|
||||||
|
type UseMutationOptions,
|
||||||
|
useMutation,
|
||||||
|
useQueryClient,
|
||||||
|
} from "@tanstack/react-query";
|
||||||
|
|
||||||
|
export interface DeleteGroupRequest {
|
||||||
|
group_id: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DeleteGroupResponse {
|
||||||
|
success: boolean;
|
||||||
|
error?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useDeleteGroupMutation = (
|
||||||
|
options?: Omit<
|
||||||
|
UseMutationOptions<DeleteGroupResponse, Error, DeleteGroupRequest>,
|
||||||
|
"mutationFn"
|
||||||
|
>,
|
||||||
|
) => {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
|
||||||
|
async function deleteGroup(
|
||||||
|
variables: DeleteGroupRequest,
|
||||||
|
): Promise<DeleteGroupResponse> {
|
||||||
|
const response = await fetch(`/api/groups/${variables.group_id}`, {
|
||||||
|
method: "DELETE",
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(data.error || "Failed to delete group");
|
||||||
|
}
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: deleteGroup,
|
||||||
|
onSuccess: (...args) => {
|
||||||
|
queryClient.invalidateQueries({
|
||||||
|
queryKey: ["groups"],
|
||||||
|
});
|
||||||
|
options?.onSuccess?.(...args);
|
||||||
|
},
|
||||||
|
onError: options?.onError,
|
||||||
|
onSettled: options?.onSettled,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
|
@ -6,6 +6,8 @@ export interface ApiKey {
|
||||||
key_prefix: string;
|
key_prefix: string;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
last_used_at: string | null;
|
last_used_at: string | null;
|
||||||
|
roles?: string[];
|
||||||
|
groups?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface GetApiKeysResponse {
|
export interface GetApiKeysResponse {
|
||||||
|
|
|
||||||
32
frontend/app/api/queries/useGetGroupsQuery.ts
Normal file
32
frontend/app/api/queries/useGetGroupsQuery.ts
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
import { type UseQueryOptions, useQuery } from "@tanstack/react-query";
|
||||||
|
|
||||||
|
export interface Group {
|
||||||
|
group_id: string;
|
||||||
|
name: string;
|
||||||
|
description: string;
|
||||||
|
created_at: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GetGroupsResponse {
|
||||||
|
success: boolean;
|
||||||
|
groups: Group[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export const useGetGroupsQuery = (
|
||||||
|
options?: Omit<UseQueryOptions<GetGroupsResponse>, "queryKey" | "queryFn">,
|
||||||
|
) => {
|
||||||
|
async function getGroups(): Promise<GetGroupsResponse> {
|
||||||
|
const response = await fetch("/api/groups");
|
||||||
|
if (response.ok) {
|
||||||
|
return await response.json();
|
||||||
|
}
|
||||||
|
throw new Error("Failed to fetch groups");
|
||||||
|
}
|
||||||
|
|
||||||
|
return useQuery({
|
||||||
|
queryKey: ["groups"],
|
||||||
|
queryFn: getGroups,
|
||||||
|
...options,
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
|
@ -12,8 +12,11 @@ import {
|
||||||
useGetOpenAIModelsQuery,
|
useGetOpenAIModelsQuery,
|
||||||
} from "@/app/api/queries/useGetModelsQuery";
|
} from "@/app/api/queries/useGetModelsQuery";
|
||||||
import { useGetApiKeysQuery } from "@/app/api/queries/useGetApiKeysQuery";
|
import { useGetApiKeysQuery } from "@/app/api/queries/useGetApiKeysQuery";
|
||||||
|
import { useGetGroupsQuery } from "@/app/api/queries/useGetGroupsQuery";
|
||||||
import { useCreateApiKeyMutation } from "@/app/api/mutations/useCreateApiKeyMutation";
|
import { useCreateApiKeyMutation } from "@/app/api/mutations/useCreateApiKeyMutation";
|
||||||
import { useRevokeApiKeyMutation } from "@/app/api/mutations/useRevokeApiKeyMutation";
|
import { useRevokeApiKeyMutation } from "@/app/api/mutations/useRevokeApiKeyMutation";
|
||||||
|
import { ManageGroupsModal } from "@/components/ManageGroupsModal";
|
||||||
|
import { MultiSelect } from "@/components/ui/multi-select";
|
||||||
import { useGetSettingsQuery } from "@/app/api/queries/useGetSettingsQuery";
|
import { useGetSettingsQuery } from "@/app/api/queries/useGetSettingsQuery";
|
||||||
import { ConfirmationDialog } from "@/components/confirmation-dialog";
|
import { ConfirmationDialog } from "@/components/confirmation-dialog";
|
||||||
import {
|
import {
|
||||||
|
|
@ -136,8 +139,10 @@ function KnowledgeSourcesPage() {
|
||||||
// API Keys state
|
// API Keys state
|
||||||
const [createKeyDialogOpen, setCreateKeyDialogOpen] = useState(false);
|
const [createKeyDialogOpen, setCreateKeyDialogOpen] = useState(false);
|
||||||
const [newKeyName, setNewKeyName] = useState("");
|
const [newKeyName, setNewKeyName] = useState("");
|
||||||
|
const [newKeyGroups, setNewKeyGroups] = useState<string[]>([]);
|
||||||
const [newlyCreatedKey, setNewlyCreatedKey] = useState<string | null>(null);
|
const [newlyCreatedKey, setNewlyCreatedKey] = useState<string | null>(null);
|
||||||
const [showKeyDialogOpen, setShowKeyDialogOpen] = useState(false);
|
const [showKeyDialogOpen, setShowKeyDialogOpen] = useState(false);
|
||||||
|
const [manageGroupsOpen, setManageGroupsOpen] = useState(false);
|
||||||
|
|
||||||
// Fetch settings using React Query
|
// Fetch settings using React Query
|
||||||
const { data: settings = {} } = useGetSettingsQuery({
|
const { data: settings = {} } = useGetSettingsQuery({
|
||||||
|
|
@ -149,6 +154,11 @@ function KnowledgeSourcesPage() {
|
||||||
enabled: isAuthenticated || isNoAuthMode,
|
enabled: isAuthenticated || isNoAuthMode,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Fetch user groups
|
||||||
|
const { data: groupsData } = useGetGroupsQuery({
|
||||||
|
enabled: isAuthenticated || isNoAuthMode,
|
||||||
|
});
|
||||||
|
|
||||||
// API key mutations
|
// API key mutations
|
||||||
const createApiKeyMutation = useCreateApiKeyMutation({
|
const createApiKeyMutation = useCreateApiKeyMutation({
|
||||||
onSuccess: (data) => {
|
onSuccess: (data) => {
|
||||||
|
|
@ -156,6 +166,7 @@ function KnowledgeSourcesPage() {
|
||||||
setCreateKeyDialogOpen(false);
|
setCreateKeyDialogOpen(false);
|
||||||
setShowKeyDialogOpen(true);
|
setShowKeyDialogOpen(true);
|
||||||
setNewKeyName("");
|
setNewKeyName("");
|
||||||
|
setNewKeyGroups([]);
|
||||||
toast.success("API key created");
|
toast.success("API key created");
|
||||||
},
|
},
|
||||||
onError: (error) => {
|
onError: (error) => {
|
||||||
|
|
@ -438,7 +449,11 @@ function KnowledgeSourcesPage() {
|
||||||
toast.error("Please enter a name for the API key");
|
toast.error("Please enter a name for the API key");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
createApiKeyMutation.mutate({ name: newKeyName.trim() });
|
// Use selected groups directly (already an array)
|
||||||
|
createApiKeyMutation.mutate({
|
||||||
|
name: newKeyName.trim(),
|
||||||
|
groups: newKeyGroups.length > 0 ? newKeyGroups : undefined,
|
||||||
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleRevokeApiKey = (keyId: string) => {
|
const handleRevokeApiKey = (keyId: string) => {
|
||||||
|
|
@ -1426,6 +1441,9 @@ function KnowledgeSourcesPage() {
|
||||||
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
||||||
Key
|
Key
|
||||||
</th>
|
</th>
|
||||||
|
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
||||||
|
Groups
|
||||||
|
</th>
|
||||||
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
<th className="text-left text-sm font-medium text-muted-foreground px-4 py-3">
|
||||||
Created
|
Created
|
||||||
</th>
|
</th>
|
||||||
|
|
@ -1448,6 +1466,24 @@ function KnowledgeSourcesPage() {
|
||||||
{key.key_prefix}...
|
{key.key_prefix}...
|
||||||
</code>
|
</code>
|
||||||
</td>
|
</td>
|
||||||
|
<td className="px-4 py-3 text-sm text-muted-foreground">
|
||||||
|
{key.groups && key.groups.length > 0 ? (
|
||||||
|
<div className="flex flex-wrap gap-1">
|
||||||
|
{key.groups.map((group: string) => (
|
||||||
|
<span
|
||||||
|
key={group}
|
||||||
|
className="inline-flex items-center px-2 py-0.5 rounded-full text-xs bg-primary/10 text-primary"
|
||||||
|
>
|
||||||
|
{group}
|
||||||
|
</span>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<span className="text-muted-foreground/50">
|
||||||
|
All groups
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</td>
|
||||||
<td className="px-4 py-3 text-sm text-muted-foreground">
|
<td className="px-4 py-3 text-sm text-muted-foreground">
|
||||||
{formatDate(key.created_at)}
|
{formatDate(key.created_at)}
|
||||||
</td>
|
</td>
|
||||||
|
|
@ -1507,58 +1543,88 @@ function KnowledgeSourcesPage() {
|
||||||
</Card>
|
</Card>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Create API Key Dialog */}
|
{/* Create API Key Dialog */}
|
||||||
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
|
<Dialog open={createKeyDialogOpen} onOpenChange={setCreateKeyDialogOpen}>
|
||||||
<DialogContent>
|
<DialogContent>
|
||||||
<DialogHeader>
|
<DialogHeader>
|
||||||
<DialogTitle>Create API Key</DialogTitle>
|
<DialogTitle>Create API Key</DialogTitle>
|
||||||
<DialogDescription>
|
<DialogDescription>
|
||||||
Give your API key a name to help you identify it later.
|
Create an API key with optional group restrictions for access
|
||||||
</DialogDescription>
|
control.
|
||||||
</DialogHeader>
|
</DialogDescription>
|
||||||
<div className="py-4">
|
</DialogHeader>
|
||||||
<LabelWrapper label="Name" id="api-key-name">
|
<div className="py-4 space-y-4">
|
||||||
<Input
|
<LabelWrapper label="Name" id="api-key-name">
|
||||||
id="api-key-name"
|
<Input
|
||||||
placeholder="e.g., Production App, Development"
|
id="api-key-name"
|
||||||
value={newKeyName}
|
placeholder="e.g., Production App, Development"
|
||||||
onChange={(e) => setNewKeyName(e.target.value)}
|
value={newKeyName}
|
||||||
onKeyDown={(e) => {
|
onChange={(e) => setNewKeyName(e.target.value)}
|
||||||
if (e.key === "Enter") {
|
onKeyDown={(e) => {
|
||||||
handleCreateApiKey();
|
if (e.key === "Enter") {
|
||||||
}
|
handleCreateApiKey();
|
||||||
}}
|
}
|
||||||
/>
|
|
||||||
</LabelWrapper>
|
|
||||||
</div>
|
|
||||||
<DialogFooter>
|
|
||||||
<Button
|
|
||||||
variant="ghost"
|
|
||||||
onClick={() => {
|
|
||||||
setCreateKeyDialogOpen(false);
|
|
||||||
setNewKeyName("");
|
|
||||||
}}
|
}}
|
||||||
size="sm"
|
/>
|
||||||
>
|
</LabelWrapper>
|
||||||
Cancel
|
<div className="space-y-2">
|
||||||
</Button>
|
<LabelWrapper
|
||||||
<Button
|
label="Groups (optional)"
|
||||||
onClick={handleCreateApiKey}
|
id="api-key-groups"
|
||||||
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
|
helperText="Restrict this key to specific groups"
|
||||||
size="sm"
|
>
|
||||||
>
|
<MultiSelect
|
||||||
{createApiKeyMutation.isPending ? (
|
options={(groupsData?.groups || []).map((g) => ({
|
||||||
<>
|
value: g.name,
|
||||||
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
|
label: g.name,
|
||||||
Creating...
|
}))}
|
||||||
</>
|
value={newKeyGroups}
|
||||||
) : (
|
onValueChange={setNewKeyGroups}
|
||||||
"Create Key"
|
placeholder="Select groups..."
|
||||||
)}
|
showAllOption={false}
|
||||||
</Button>
|
searchPlaceholder="Search groups..."
|
||||||
</DialogFooter>
|
/>
|
||||||
</DialogContent>
|
</LabelWrapper>
|
||||||
</Dialog>
|
<Button
|
||||||
|
variant="link"
|
||||||
|
size="sm"
|
||||||
|
className="h-auto p-0 text-xs"
|
||||||
|
onClick={() => setManageGroupsOpen(true)}
|
||||||
|
>
|
||||||
|
<Plus className="h-3 w-3 mr-1" />
|
||||||
|
Manage Groups
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<DialogFooter>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
onClick={() => {
|
||||||
|
setCreateKeyDialogOpen(false);
|
||||||
|
setNewKeyName("");
|
||||||
|
setNewKeyGroups([]);
|
||||||
|
}}
|
||||||
|
size="sm"
|
||||||
|
>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
onClick={handleCreateApiKey}
|
||||||
|
disabled={createApiKeyMutation.isPending || !newKeyName.trim()}
|
||||||
|
size="sm"
|
||||||
|
>
|
||||||
|
{createApiKeyMutation.isPending ? (
|
||||||
|
<>
|
||||||
|
<Loader2 className="h-4 w-4 mr-2 animate-spin" />
|
||||||
|
Creating...
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
"Create Key"
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
|
||||||
{/* Show Created API Key Dialog */}
|
{/* Show Created API Key Dialog */}
|
||||||
<Dialog
|
<Dialog
|
||||||
|
|
@ -1593,6 +1659,12 @@ function KnowledgeSourcesPage() {
|
||||||
</DialogFooter>
|
</DialogFooter>
|
||||||
</DialogContent>
|
</DialogContent>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
||||||
|
{/* Manage Groups Modal */}
|
||||||
|
<ManageGroupsModal
|
||||||
|
open={manageGroupsOpen}
|
||||||
|
onOpenChange={setManageGroupsOpen}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
175
frontend/components/ManageGroupsModal.tsx
Normal file
175
frontend/components/ManageGroupsModal.tsx
Normal file
|
|
@ -0,0 +1,175 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useState } from "react";
|
||||||
|
import {
|
||||||
|
Dialog,
|
||||||
|
DialogContent,
|
||||||
|
DialogDescription,
|
||||||
|
DialogHeader,
|
||||||
|
DialogTitle,
|
||||||
|
} from "@/components/ui/dialog";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { useGetGroupsQuery } from "@/app/api/queries/useGetGroupsQuery";
|
||||||
|
import { useCreateGroupMutation } from "@/app/api/mutations/useCreateGroupMutation";
|
||||||
|
import { useDeleteGroupMutation } from "@/app/api/mutations/useDeleteGroupMutation";
|
||||||
|
import { Plus, Trash2, Loader2, Users } from "lucide-react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
|
||||||
|
interface ManageGroupsModalProps {
|
||||||
|
open: boolean;
|
||||||
|
onOpenChange: (open: boolean) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ManageGroupsModal({
|
||||||
|
open,
|
||||||
|
onOpenChange,
|
||||||
|
}: ManageGroupsModalProps) {
|
||||||
|
const [newGroupName, setNewGroupName] = useState("");
|
||||||
|
const [newGroupDescription, setNewGroupDescription] = useState("");
|
||||||
|
|
||||||
|
const { data: groupsData, isLoading: groupsLoading } = useGetGroupsQuery({
|
||||||
|
enabled: open,
|
||||||
|
});
|
||||||
|
|
||||||
|
const createGroupMutation = useCreateGroupMutation({
|
||||||
|
onSuccess: () => {
|
||||||
|
setNewGroupName("");
|
||||||
|
setNewGroupDescription("");
|
||||||
|
toast.success("Group created successfully");
|
||||||
|
},
|
||||||
|
onError: (error) => {
|
||||||
|
toast.error("Failed to create group", { description: error.message });
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const deleteGroupMutation = useDeleteGroupMutation({
|
||||||
|
onSuccess: () => {
|
||||||
|
toast.success("Group deleted successfully");
|
||||||
|
},
|
||||||
|
onError: (error) => {
|
||||||
|
toast.error("Failed to delete group", { description: error.message });
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
const handleCreateGroup = () => {
|
||||||
|
if (!newGroupName.trim()) {
|
||||||
|
toast.error("Please enter a group name");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
createGroupMutation.mutate({
|
||||||
|
name: newGroupName.trim(),
|
||||||
|
description: newGroupDescription.trim(),
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDeleteGroup = (groupId: string, groupName: string) => {
|
||||||
|
if (confirm(`Are you sure you want to delete the group "${groupName}"?`)) {
|
||||||
|
deleteGroupMutation.mutate({ group_id: groupId });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const groups = groupsData?.groups || [];
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||||
|
<DialogContent className="max-w-md">
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>Manage User Groups</DialogTitle>
|
||||||
|
<DialogDescription>
|
||||||
|
Create and manage user groups for access control. Groups can be
|
||||||
|
assigned to API keys to restrict document access.
|
||||||
|
</DialogDescription>
|
||||||
|
</DialogHeader>
|
||||||
|
|
||||||
|
<div className="space-y-4 py-4">
|
||||||
|
{/* Add new group section */}
|
||||||
|
<div className="space-y-2">
|
||||||
|
<label className="text-sm font-medium">Create New Group</label>
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<Input
|
||||||
|
placeholder="Group name"
|
||||||
|
value={newGroupName}
|
||||||
|
onChange={(e) => setNewGroupName(e.target.value)}
|
||||||
|
onKeyDown={(e) => {
|
||||||
|
if (e.key === "Enter") {
|
||||||
|
handleCreateGroup();
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
onClick={handleCreateGroup}
|
||||||
|
disabled={
|
||||||
|
createGroupMutation.isPending || !newGroupName.trim()
|
||||||
|
}
|
||||||
|
size="sm"
|
||||||
|
>
|
||||||
|
{createGroupMutation.isPending ? (
|
||||||
|
<Loader2 className="h-4 w-4 animate-spin" />
|
||||||
|
) : (
|
||||||
|
<Plus className="h-4 w-4" />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
<Input
|
||||||
|
placeholder="Description (optional)"
|
||||||
|
value={newGroupDescription}
|
||||||
|
onChange={(e) => setNewGroupDescription(e.target.value)}
|
||||||
|
className="text-sm"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Existing groups list */}
|
||||||
|
<div className="space-y-2">
|
||||||
|
<label className="text-sm font-medium">Existing Groups</label>
|
||||||
|
{groupsLoading ? (
|
||||||
|
<div className="flex items-center justify-center py-4">
|
||||||
|
<Loader2 className="h-5 w-5 animate-spin text-muted-foreground" />
|
||||||
|
</div>
|
||||||
|
) : groups.length > 0 ? (
|
||||||
|
<div className="border rounded-lg divide-y max-h-48 overflow-y-auto">
|
||||||
|
{groups.map((group) => (
|
||||||
|
<div
|
||||||
|
key={group.group_id}
|
||||||
|
className="flex items-center justify-between px-3 py-2 hover:bg-muted/50"
|
||||||
|
>
|
||||||
|
<div className="flex-1 min-w-0">
|
||||||
|
<p className="text-sm font-medium truncate">
|
||||||
|
{group.name}
|
||||||
|
</p>
|
||||||
|
{group.description && (
|
||||||
|
<p className="text-xs text-muted-foreground truncate">
|
||||||
|
{group.description}
|
||||||
|
</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
className="h-8 w-8 p-0 text-destructive hover:text-destructive hover:bg-destructive/10"
|
||||||
|
onClick={() =>
|
||||||
|
handleDeleteGroup(group.group_id, group.name)
|
||||||
|
}
|
||||||
|
disabled={deleteGroupMutation.isPending}
|
||||||
|
>
|
||||||
|
<Trash2 className="h-4 w-4" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="text-center py-6 border rounded-lg">
|
||||||
|
<Users className="h-8 w-8 mx-auto text-muted-foreground/50 mb-2" />
|
||||||
|
<p className="text-sm text-muted-foreground">
|
||||||
|
No groups yet. Create one above.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -35,6 +35,10 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
||||||
set_search_limit(limit)
|
set_search_limit(limit)
|
||||||
set_score_threshold(score_threshold)
|
set_score_threshold(score_threshold)
|
||||||
|
|
||||||
|
# Get RBAC groups and roles from user for access control
|
||||||
|
user_groups = getattr(user, "groups", [])
|
||||||
|
user_roles = getattr(user, "roles", [])
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
await chat_service.chat(
|
await chat_service.chat(
|
||||||
|
|
@ -44,6 +48,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=previous_response_id,
|
previous_response_id=previous_response_id,
|
||||||
stream=True,
|
stream=True,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
|
|
@ -61,6 +67,8 @@ async def chat_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=previous_response_id,
|
previous_response_id=previous_response_id,
|
||||||
stream=False,
|
stream=False,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
)
|
)
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
|
|
||||||
|
|
@ -81,6 +89,10 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
||||||
|
|
||||||
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
|
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
|
||||||
|
|
||||||
|
# Get RBAC groups and roles from user for access control
|
||||||
|
user_groups = getattr(user, "groups", [])
|
||||||
|
user_roles = getattr(user, "roles", [])
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
return JSONResponse({"error": "Prompt is required"}, status_code=400)
|
||||||
|
|
||||||
|
|
@ -105,6 +117,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=previous_response_id,
|
previous_response_id=previous_response_id,
|
||||||
stream=True,
|
stream=True,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={
|
headers={
|
||||||
|
|
@ -122,6 +136,8 @@ async def langflow_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=previous_response_id,
|
previous_response_id=previous_response_id,
|
||||||
stream=False,
|
stream=False,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
)
|
)
|
||||||
return JSONResponse(result)
|
return JSONResponse(result)
|
||||||
|
|
||||||
|
|
|
||||||
114
src/api/groups.py
Normal file
114
src/api/groups.py
Normal file
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""
|
||||||
|
User Groups management endpoints.
|
||||||
|
|
||||||
|
These endpoints allow managing user groups for RBAC.
|
||||||
|
"""
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from utils.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_groups_endpoint(request: Request, group_service):
|
||||||
|
"""
|
||||||
|
List all user groups.
|
||||||
|
|
||||||
|
GET /groups
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"groups": [
|
||||||
|
{
|
||||||
|
"group_id": "...",
|
||||||
|
"name": "finance",
|
||||||
|
"description": "Finance team",
|
||||||
|
"created_at": "2024-01-01T00:00:00"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = await group_service.list_groups()
|
||||||
|
return JSONResponse(result)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_group_endpoint(request: Request, group_service):
|
||||||
|
"""
|
||||||
|
Create a new user group.
|
||||||
|
|
||||||
|
POST /groups
|
||||||
|
Body: {"name": "finance", "description": "Finance team"}
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"group_id": "...",
|
||||||
|
"name": "finance",
|
||||||
|
"description": "Finance team",
|
||||||
|
"created_at": "2024-01-01T00:00:00"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
name = data.get("name", "").strip()
|
||||||
|
description = data.get("description", "").strip()
|
||||||
|
|
||||||
|
if not name:
|
||||||
|
return JSONResponse(
|
||||||
|
{"success": False, "error": "Name is required"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(name) > 100:
|
||||||
|
return JSONResponse(
|
||||||
|
{"success": False, "error": "Name must be 100 characters or less"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await group_service.create_group(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.get("success"):
|
||||||
|
return JSONResponse(result)
|
||||||
|
elif "already exists" in result.get("error", ""):
|
||||||
|
return JSONResponse(result, status_code=409)
|
||||||
|
else:
|
||||||
|
return JSONResponse(result, status_code=500)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create group: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
{"success": False, "error": str(e)},
|
||||||
|
status_code=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_group_endpoint(request: Request, group_service):
|
||||||
|
"""
|
||||||
|
Delete a user group.
|
||||||
|
|
||||||
|
DELETE /groups/{group_id}
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{"success": true}
|
||||||
|
"""
|
||||||
|
group_id = request.path_params.get("group_id")
|
||||||
|
|
||||||
|
if not group_id:
|
||||||
|
return JSONResponse(
|
||||||
|
{"success": False, "error": "Group ID is required"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await group_service.delete_group(group_id=group_id)
|
||||||
|
|
||||||
|
if result.get("success"):
|
||||||
|
return JSONResponse(result)
|
||||||
|
elif result.get("error") == "Group not found":
|
||||||
|
return JSONResponse(result, status_code=404)
|
||||||
|
else:
|
||||||
|
return JSONResponse(result, status_code=500)
|
||||||
|
|
||||||
|
|
@ -42,10 +42,14 @@ async def list_keys_endpoint(request: Request, api_key_service):
|
||||||
|
|
||||||
async def create_key_endpoint(request: Request, api_key_service):
|
async def create_key_endpoint(request: Request, api_key_service):
|
||||||
"""
|
"""
|
||||||
Create a new API key for the authenticated user.
|
Create a new API key for the authenticated user with optional RBAC restrictions.
|
||||||
|
|
||||||
POST /keys
|
POST /keys
|
||||||
Body: {"name": "My API Key"}
|
Body: {
|
||||||
|
"name": "My API Key",
|
||||||
|
"roles": ["openrag_user"], // Optional: restrict key to specific roles
|
||||||
|
"groups": ["finance", "hr"] // Optional: restrict key to specific groups
|
||||||
|
}
|
||||||
|
|
||||||
Response:
|
Response:
|
||||||
{
|
{
|
||||||
|
|
@ -53,6 +57,8 @@ async def create_key_endpoint(request: Request, api_key_service):
|
||||||
"key_id": "...",
|
"key_id": "...",
|
||||||
"key_prefix": "orag_abc12345",
|
"key_prefix": "orag_abc12345",
|
||||||
"name": "My API Key",
|
"name": "My API Key",
|
||||||
|
"roles": ["openrag_user"],
|
||||||
|
"groups": ["finance", "hr"],
|
||||||
"created_at": "2024-01-01T00:00:00",
|
"created_at": "2024-01-01T00:00:00",
|
||||||
"api_key": "orag_abc12345..." // Full key, only shown once!
|
"api_key": "orag_abc12345..." // Full key, only shown once!
|
||||||
}
|
}
|
||||||
|
|
@ -78,11 +84,29 @@ async def create_key_endpoint(request: Request, api_key_service):
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract optional RBAC fields
|
||||||
|
roles = data.get("roles")
|
||||||
|
groups = data.get("groups")
|
||||||
|
|
||||||
|
# Validate roles and groups are lists if provided
|
||||||
|
if roles is not None and not isinstance(roles, list):
|
||||||
|
return JSONResponse(
|
||||||
|
{"success": False, "error": "roles must be a list"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
if groups is not None and not isinstance(groups, list):
|
||||||
|
return JSONResponse(
|
||||||
|
{"success": False, "error": "groups must be a list"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
result = await api_key_service.create_key(
|
result = await api_key_service.create_key(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
user_email=user_email,
|
user_email=user_email,
|
||||||
name=name,
|
name=name,
|
||||||
jwt_token=jwt_token,
|
jwt_token=jwt_token,
|
||||||
|
roles=roles,
|
||||||
|
groups=groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result.get("success"):
|
if result.get("success"):
|
||||||
|
|
|
||||||
|
|
@ -104,14 +104,18 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
||||||
|
|
||||||
user = request.state.user
|
user = request.state.user
|
||||||
user_id = user.user_id
|
user_id = user.user_id
|
||||||
jwt_token = session_manager.get_effective_jwt_token(user_id, None)
|
jwt_token = session_manager.get_effective_jwt_token(user_id, request.state.jwt_token)
|
||||||
|
|
||||||
|
# Get RBAC groups and roles from user for access control
|
||||||
|
user_groups = getattr(user, "groups", [])
|
||||||
|
user_roles = getattr(user, "roles", [])
|
||||||
|
|
||||||
# Set context variables for search tool
|
# Set context variables for search tool
|
||||||
if filters:
|
if filters:
|
||||||
set_search_filters(filters)
|
set_search_filters(filters)
|
||||||
set_search_limit(limit)
|
set_search_limit(limit)
|
||||||
set_score_threshold(score_threshold)
|
set_score_threshold(score_threshold)
|
||||||
set_auth_context(user_id, jwt_token)
|
set_auth_context(user_id, jwt_token, groups=user_groups, roles=user_roles)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
raw_stream = await chat_service.langflow_chat(
|
raw_stream = await chat_service.langflow_chat(
|
||||||
|
|
@ -121,6 +125,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=chat_id,
|
previous_response_id=chat_id,
|
||||||
stream=True,
|
stream=True,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
)
|
)
|
||||||
chat_id_container = {}
|
chat_id_container = {}
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
|
|
@ -136,6 +142,8 @@ async def chat_create_endpoint(request: Request, chat_service, session_manager):
|
||||||
previous_response_id=chat_id,
|
previous_response_id=chat_id,
|
||||||
stream=False,
|
stream=False,
|
||||||
filter_id=filter_id,
|
filter_id=filter_id,
|
||||||
|
groups=user_groups,
|
||||||
|
roles=user_roles,
|
||||||
)
|
)
|
||||||
# Transform response_id to chat_id for v1 API format
|
# Transform response_id to chat_id for v1 API format
|
||||||
return JSONResponse({
|
return JSONResponse({
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
"""
|
"""
|
||||||
API Key middleware for authenticating public API requests.
|
API Key middleware for authenticating public API requests.
|
||||||
|
|
||||||
|
This middleware validates API keys and generates ephemeral JWTs with the
|
||||||
|
key's specific roles and groups for downstream security enforcement.
|
||||||
"""
|
"""
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
@ -32,14 +35,18 @@ def _extract_api_key(request: Request) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def require_api_key(api_key_service):
|
def require_api_key(api_key_service, session_manager=None):
|
||||||
"""
|
"""
|
||||||
Decorator to require API key authentication for public API endpoints.
|
Decorator to require API key authentication for public API endpoints.
|
||||||
|
|
||||||
|
Generates an ephemeral JWT with the API key's specific roles and groups
|
||||||
|
to enforce RBAC in downstream services (OpenSearch, tools, etc.).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@require_api_key(api_key_service)
|
@require_api_key(api_key_service, session_manager)
|
||||||
async def my_endpoint(request):
|
async def my_endpoint(request):
|
||||||
user = request.state.user
|
user = request.state.user
|
||||||
|
jwt_token = request.state.jwt_token # Ephemeral restricted JWT
|
||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -57,7 +64,7 @@ def require_api_key(api_key_service):
|
||||||
status_code=401,
|
status_code=401,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate the key
|
# Validate the key and get RBAC claims
|
||||||
user_info = await api_key_service.validate_key(api_key)
|
user_info = await api_key_service.validate_key(api_key)
|
||||||
|
|
||||||
if not user_info:
|
if not user_info:
|
||||||
|
|
@ -69,19 +76,46 @@ def require_api_key(api_key_service):
|
||||||
status_code=401,
|
status_code=401,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a User object from the API key info
|
# Extract RBAC fields from API key
|
||||||
|
key_roles = user_info.get("roles", ["openrag_user"])
|
||||||
|
key_groups = user_info.get("groups", [])
|
||||||
|
|
||||||
|
# Create a User object with the API key's roles and groups
|
||||||
user = User(
|
user = User(
|
||||||
user_id=user_info["user_id"],
|
user_id=user_info["user_id"],
|
||||||
email=user_info["user_email"],
|
email=user_info["user_email"],
|
||||||
name=user_info.get("name", "API User"),
|
name=user_info.get("name", "API User"),
|
||||||
picture=None,
|
picture=None,
|
||||||
provider="api_key",
|
provider="api_key",
|
||||||
|
roles=key_roles,
|
||||||
|
groups=key_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set request state
|
# Set request state
|
||||||
request.state.user = user
|
request.state.user = user
|
||||||
request.state.api_key_id = user_info["key_id"]
|
request.state.api_key_id = user_info["key_id"]
|
||||||
request.state.jwt_token = None # No JWT for API key auth
|
|
||||||
|
# Generate ephemeral JWT with the API key's restricted roles/groups
|
||||||
|
if session_manager:
|
||||||
|
# Create a short-lived JWT with the key's specific permissions
|
||||||
|
ephemeral_jwt = session_manager.create_jwt_token(
|
||||||
|
user=user,
|
||||||
|
roles=key_roles,
|
||||||
|
groups=key_groups,
|
||||||
|
expiration_days=1, # Short-lived for API requests
|
||||||
|
)
|
||||||
|
request.state.jwt_token = ephemeral_jwt
|
||||||
|
logger.debug(
|
||||||
|
"Generated ephemeral JWT for API key",
|
||||||
|
key_id=user_info["key_id"],
|
||||||
|
roles=key_roles,
|
||||||
|
groups=key_groups,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
request.state.jwt_token = None
|
||||||
|
logger.warning(
|
||||||
|
"No session_manager provided - JWT not generated for API key"
|
||||||
|
)
|
||||||
|
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
|
||||||
|
|
@ -90,10 +124,13 @@ def require_api_key(api_key_service):
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def optional_api_key(api_key_service):
|
def optional_api_key(api_key_service, session_manager=None):
|
||||||
"""
|
"""
|
||||||
Decorator to optionally authenticate with API key.
|
Decorator to optionally authenticate with API key.
|
||||||
Sets request.state.user to None if no valid API key is provided.
|
Sets request.state.user to None if no valid API key is provided.
|
||||||
|
|
||||||
|
When a valid API key is provided, generates an ephemeral JWT with
|
||||||
|
the key's specific roles and groups.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(handler):
|
def decorator(handler):
|
||||||
|
|
@ -102,21 +139,38 @@ def optional_api_key(api_key_service):
|
||||||
api_key = _extract_api_key(request)
|
api_key = _extract_api_key(request)
|
||||||
|
|
||||||
if api_key:
|
if api_key:
|
||||||
# Validate the key
|
# Validate the key and get RBAC claims
|
||||||
user_info = await api_key_service.validate_key(api_key)
|
user_info = await api_key_service.validate_key(api_key)
|
||||||
|
|
||||||
if user_info:
|
if user_info:
|
||||||
# Create a User object from the API key info
|
# Extract RBAC fields from API key
|
||||||
|
key_roles = user_info.get("roles", ["openrag_user"])
|
||||||
|
key_groups = user_info.get("groups", [])
|
||||||
|
|
||||||
|
# Create a User object with the API key's roles and groups
|
||||||
user = User(
|
user = User(
|
||||||
user_id=user_info["user_id"],
|
user_id=user_info["user_id"],
|
||||||
email=user_info["user_email"],
|
email=user_info["user_email"],
|
||||||
name=user_info.get("name", "API User"),
|
name=user_info.get("name", "API User"),
|
||||||
picture=None,
|
picture=None,
|
||||||
provider="api_key",
|
provider="api_key",
|
||||||
|
roles=key_roles,
|
||||||
|
groups=key_groups,
|
||||||
)
|
)
|
||||||
request.state.user = user
|
request.state.user = user
|
||||||
request.state.api_key_id = user_info["key_id"]
|
request.state.api_key_id = user_info["key_id"]
|
||||||
request.state.jwt_token = None
|
|
||||||
|
# Generate ephemeral JWT with the API key's restricted roles/groups
|
||||||
|
if session_manager:
|
||||||
|
ephemeral_jwt = session_manager.create_jwt_token(
|
||||||
|
user=user,
|
||||||
|
roles=key_roles,
|
||||||
|
groups=key_groups,
|
||||||
|
expiration_days=1,
|
||||||
|
)
|
||||||
|
request.state.jwt_token = ephemeral_jwt
|
||||||
|
else:
|
||||||
|
request.state.jwt_token = None
|
||||||
else:
|
else:
|
||||||
request.state.user = None
|
request.state.user = None
|
||||||
request.state.api_key_id = None
|
request.state.api_key_id = None
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ Uses contextvars to safely pass user auth info through async calls.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any, List
|
||||||
|
|
||||||
# Context variables for current request authentication
|
# Context variables for current request authentication
|
||||||
_current_user_id: ContextVar[Optional[str]] = ContextVar(
|
_current_user_id: ContextVar[Optional[str]] = ContextVar(
|
||||||
|
|
@ -13,6 +13,12 @@ _current_user_id: ContextVar[Optional[str]] = ContextVar(
|
||||||
_current_jwt_token: ContextVar[Optional[str]] = ContextVar(
|
_current_jwt_token: ContextVar[Optional[str]] = ContextVar(
|
||||||
"current_jwt_token", default=None
|
"current_jwt_token", default=None
|
||||||
)
|
)
|
||||||
|
_current_user_groups: ContextVar[Optional[List[str]]] = ContextVar(
|
||||||
|
"current_user_groups", default=None
|
||||||
|
)
|
||||||
|
_current_user_roles: ContextVar[Optional[List[str]]] = ContextVar(
|
||||||
|
"current_user_roles", default=None
|
||||||
|
)
|
||||||
_current_search_filters: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
|
_current_search_filters: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
|
||||||
"current_search_filters", default=None
|
"current_search_filters", default=None
|
||||||
)
|
)
|
||||||
|
|
@ -24,10 +30,24 @@ _current_score_threshold: ContextVar[Optional[float]] = ContextVar(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_auth_context(user_id: str, jwt_token: str):
|
def set_auth_context(
|
||||||
"""Set authentication context for the current async context"""
|
user_id: str,
|
||||||
|
jwt_token: str,
|
||||||
|
groups: Optional[List[str]] = None,
|
||||||
|
roles: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
"""Set authentication context for the current async context
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
jwt_token: The JWT token for authentication
|
||||||
|
groups: Optional list of groups the user belongs to (for RBAC)
|
||||||
|
roles: Optional list of roles the user has (for RBAC)
|
||||||
|
"""
|
||||||
_current_user_id.set(user_id)
|
_current_user_id.set(user_id)
|
||||||
_current_jwt_token.set(jwt_token)
|
_current_jwt_token.set(jwt_token)
|
||||||
|
_current_user_groups.set(groups or [])
|
||||||
|
_current_user_roles.set(roles or [])
|
||||||
|
|
||||||
|
|
||||||
def get_current_user_id() -> Optional[str]:
|
def get_current_user_id() -> Optional[str]:
|
||||||
|
|
@ -40,6 +60,16 @@ def get_current_jwt_token() -> Optional[str]:
|
||||||
return _current_jwt_token.get()
|
return _current_jwt_token.get()
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_groups() -> List[str]:
|
||||||
|
"""Get current user's groups from context (for RBAC)"""
|
||||||
|
return _current_user_groups.get() or []
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_roles() -> List[str]:
|
||||||
|
"""Get current user's roles from context (for RBAC)"""
|
||||||
|
return _current_user_roles.get() or []
|
||||||
|
|
||||||
|
|
||||||
def get_auth_context() -> tuple[Optional[str], Optional[str]]:
|
def get_auth_context() -> tuple[Optional[str], Optional[str]]:
|
||||||
"""Get current authentication context (user_id, jwt_token)"""
|
"""Get current authentication context (user_id, jwt_token)"""
|
||||||
return _current_user_id.get(), _current_jwt_token.get()
|
return _current_user_id.get(), _current_jwt_token.get()
|
||||||
|
|
|
||||||
|
|
@ -165,6 +165,23 @@ API_KEYS_INDEX_BODY = {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# User Groups index for RBAC management
|
||||||
|
GROUPS_INDEX_NAME = "openrag_groups"
|
||||||
|
GROUPS_INDEX_BODY = {
|
||||||
|
"settings": {
|
||||||
|
"number_of_shards": 1,
|
||||||
|
"number_of_replicas": 0,
|
||||||
|
},
|
||||||
|
"mappings": {
|
||||||
|
"properties": {
|
||||||
|
"group_id": {"type": "keyword"},
|
||||||
|
"name": {"type": "keyword"},
|
||||||
|
"description": {"type": "text"},
|
||||||
|
"created_at": {"type": "date"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
# Convenience base URL for Langflow REST API
|
# Convenience base URL for Langflow REST API
|
||||||
LANGFLOW_BASE_URL = f"{LANGFLOW_URL}/api/v1"
|
LANGFLOW_BASE_URL = f"{LANGFLOW_URL}/api/v1"
|
||||||
|
|
||||||
|
|
|
||||||
69
src/main.py
69
src/main.py
|
|
@ -58,6 +58,10 @@ from auth_middleware import optional_auth, require_auth
|
||||||
from api_key_middleware import require_api_key
|
from api_key_middleware import require_api_key
|
||||||
from services.api_key_service import APIKeyService
|
from services.api_key_service import APIKeyService
|
||||||
from api import keys as api_keys
|
from api import keys as api_keys
|
||||||
|
|
||||||
|
# User Groups management
|
||||||
|
from services.group_service import GroupService
|
||||||
|
from api import groups as api_groups
|
||||||
from api.v1 import chat as v1_chat, search as v1_search, documents as v1_documents, settings as v1_settings, knowledge_filters as v1_knowledge_filters
|
from api.v1 import chat as v1_chat, search as v1_search, documents as v1_documents, settings as v1_settings, knowledge_filters as v1_knowledge_filters
|
||||||
|
|
||||||
# Configuration and setup
|
# Configuration and setup
|
||||||
|
|
@ -665,6 +669,9 @@ async def initialize_services():
|
||||||
# API Key service for public API authentication
|
# API Key service for public API authentication
|
||||||
api_key_service = APIKeyService(session_manager)
|
api_key_service = APIKeyService(session_manager)
|
||||||
|
|
||||||
|
# Group service for RBAC management
|
||||||
|
group_service = GroupService(session_manager)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"document_service": document_service,
|
"document_service": document_service,
|
||||||
"search_service": search_service,
|
"search_service": search_service,
|
||||||
|
|
@ -679,6 +686,7 @@ async def initialize_services():
|
||||||
"monitor_service": monitor_service,
|
"monitor_service": monitor_service,
|
||||||
"session_manager": session_manager,
|
"session_manager": session_manager,
|
||||||
"api_key_service": api_key_service,
|
"api_key_service": api_key_service,
|
||||||
|
"group_service": group_service,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1310,11 +1318,42 @@ async def create_app():
|
||||||
),
|
),
|
||||||
methods=["DELETE"],
|
methods=["DELETE"],
|
||||||
),
|
),
|
||||||
|
# ===== User Groups Management Endpoints (JWT auth for UI) =====
|
||||||
|
Route(
|
||||||
|
"/groups",
|
||||||
|
require_auth(services["session_manager"])(
|
||||||
|
partial(
|
||||||
|
api_groups.list_groups_endpoint,
|
||||||
|
group_service=services["group_service"],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
methods=["GET"],
|
||||||
|
),
|
||||||
|
Route(
|
||||||
|
"/groups",
|
||||||
|
require_auth(services["session_manager"])(
|
||||||
|
partial(
|
||||||
|
api_groups.create_group_endpoint,
|
||||||
|
group_service=services["group_service"],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
methods=["POST"],
|
||||||
|
),
|
||||||
|
Route(
|
||||||
|
"/groups/{group_id}",
|
||||||
|
require_auth(services["session_manager"])(
|
||||||
|
partial(
|
||||||
|
api_groups.delete_group_endpoint,
|
||||||
|
group_service=services["group_service"],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
methods=["DELETE"],
|
||||||
|
),
|
||||||
# ===== Public API v1 Endpoints (API Key auth) =====
|
# ===== Public API v1 Endpoints (API Key auth) =====
|
||||||
# Chat endpoints
|
# Chat endpoints
|
||||||
Route(
|
Route(
|
||||||
"/v1/chat",
|
"/v1/chat",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_chat.chat_create_endpoint,
|
v1_chat.chat_create_endpoint,
|
||||||
chat_service=services["chat_service"],
|
chat_service=services["chat_service"],
|
||||||
|
|
@ -1325,7 +1364,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/chat",
|
"/v1/chat",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_chat.chat_list_endpoint,
|
v1_chat.chat_list_endpoint,
|
||||||
chat_service=services["chat_service"],
|
chat_service=services["chat_service"],
|
||||||
|
|
@ -1336,7 +1375,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/chat/{chat_id}",
|
"/v1/chat/{chat_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_chat.chat_get_endpoint,
|
v1_chat.chat_get_endpoint,
|
||||||
chat_service=services["chat_service"],
|
chat_service=services["chat_service"],
|
||||||
|
|
@ -1347,7 +1386,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/chat/{chat_id}",
|
"/v1/chat/{chat_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_chat.chat_delete_endpoint,
|
v1_chat.chat_delete_endpoint,
|
||||||
chat_service=services["chat_service"],
|
chat_service=services["chat_service"],
|
||||||
|
|
@ -1359,7 +1398,7 @@ async def create_app():
|
||||||
# Search endpoint
|
# Search endpoint
|
||||||
Route(
|
Route(
|
||||||
"/v1/search",
|
"/v1/search",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_search.search_endpoint,
|
v1_search.search_endpoint,
|
||||||
search_service=services["search_service"],
|
search_service=services["search_service"],
|
||||||
|
|
@ -1371,7 +1410,7 @@ async def create_app():
|
||||||
# Documents endpoints
|
# Documents endpoints
|
||||||
Route(
|
Route(
|
||||||
"/v1/documents/ingest",
|
"/v1/documents/ingest",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_documents.ingest_endpoint,
|
v1_documents.ingest_endpoint,
|
||||||
document_service=services["document_service"],
|
document_service=services["document_service"],
|
||||||
|
|
@ -1384,7 +1423,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/tasks/{task_id}",
|
"/v1/tasks/{task_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_documents.task_status_endpoint,
|
v1_documents.task_status_endpoint,
|
||||||
task_service=services["task_service"],
|
task_service=services["task_service"],
|
||||||
|
|
@ -1395,7 +1434,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/documents",
|
"/v1/documents",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_documents.delete_document_endpoint,
|
v1_documents.delete_document_endpoint,
|
||||||
document_service=services["document_service"],
|
document_service=services["document_service"],
|
||||||
|
|
@ -1407,14 +1446,14 @@ async def create_app():
|
||||||
# Settings endpoints
|
# Settings endpoints
|
||||||
Route(
|
Route(
|
||||||
"/v1/settings",
|
"/v1/settings",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(v1_settings.get_settings_endpoint)
|
partial(v1_settings.get_settings_endpoint)
|
||||||
),
|
),
|
||||||
methods=["GET"],
|
methods=["GET"],
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/settings",
|
"/v1/settings",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_settings.update_settings_endpoint,
|
v1_settings.update_settings_endpoint,
|
||||||
session_manager=services["session_manager"],
|
session_manager=services["session_manager"],
|
||||||
|
|
@ -1425,7 +1464,7 @@ async def create_app():
|
||||||
# Knowledge filters endpoints
|
# Knowledge filters endpoints
|
||||||
Route(
|
Route(
|
||||||
"/v1/knowledge-filters",
|
"/v1/knowledge-filters",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_knowledge_filters.create_endpoint,
|
v1_knowledge_filters.create_endpoint,
|
||||||
knowledge_filter_service=services["knowledge_filter_service"],
|
knowledge_filter_service=services["knowledge_filter_service"],
|
||||||
|
|
@ -1436,7 +1475,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/knowledge-filters/search",
|
"/v1/knowledge-filters/search",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_knowledge_filters.search_endpoint,
|
v1_knowledge_filters.search_endpoint,
|
||||||
knowledge_filter_service=services["knowledge_filter_service"],
|
knowledge_filter_service=services["knowledge_filter_service"],
|
||||||
|
|
@ -1447,7 +1486,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/knowledge-filters/{filter_id}",
|
"/v1/knowledge-filters/{filter_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_knowledge_filters.get_endpoint,
|
v1_knowledge_filters.get_endpoint,
|
||||||
knowledge_filter_service=services["knowledge_filter_service"],
|
knowledge_filter_service=services["knowledge_filter_service"],
|
||||||
|
|
@ -1458,7 +1497,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/knowledge-filters/{filter_id}",
|
"/v1/knowledge-filters/{filter_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_knowledge_filters.update_endpoint,
|
v1_knowledge_filters.update_endpoint,
|
||||||
knowledge_filter_service=services["knowledge_filter_service"],
|
knowledge_filter_service=services["knowledge_filter_service"],
|
||||||
|
|
@ -1469,7 +1508,7 @@ async def create_app():
|
||||||
),
|
),
|
||||||
Route(
|
Route(
|
||||||
"/v1/knowledge-filters/{filter_id}",
|
"/v1/knowledge-filters/{filter_id}",
|
||||||
require_api_key(services["api_key_service"])(
|
require_api_key(services["api_key_service"], services["session_manager"])(
|
||||||
partial(
|
partial(
|
||||||
v1_knowledge_filters.delete_endpoint,
|
v1_knowledge_filters.delete_endpoint,
|
||||||
knowledge_filter_service=services["knowledge_filter_service"],
|
knowledge_filter_service=services["knowledge_filter_service"],
|
||||||
|
|
|
||||||
|
|
@ -158,6 +158,7 @@ class TaskProcessor:
|
||||||
connector_type: str = "local",
|
connector_type: str = "local",
|
||||||
embedding_model: str = None,
|
embedding_model: str = None,
|
||||||
is_sample_data: bool = False,
|
is_sample_data: bool = False,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Standard processing pipeline for non-Langflow processors:
|
Standard processing pipeline for non-Langflow processors:
|
||||||
|
|
@ -166,6 +167,8 @@ class TaskProcessor:
|
||||||
Args:
|
Args:
|
||||||
embedding_model: Embedding model to use (defaults to the current
|
embedding_model: Embedding model to use (defaults to the current
|
||||||
embedding model from settings)
|
embedding model from settings)
|
||||||
|
allowed_groups: List of groups that can access this document (RBAC).
|
||||||
|
Empty list or None means no group restrictions.
|
||||||
"""
|
"""
|
||||||
import datetime
|
import datetime
|
||||||
from config.settings import INDEX_NAME, clients, get_embedding_model
|
from config.settings import INDEX_NAME, clients, get_embedding_model
|
||||||
|
|
@ -259,6 +262,11 @@ class TaskProcessor:
|
||||||
if owner_email is not None:
|
if owner_email is not None:
|
||||||
chunk_doc["owner_email"] = owner_email
|
chunk_doc["owner_email"] = owner_email
|
||||||
|
|
||||||
|
# RBAC: Set allowed groups for access control
|
||||||
|
# If allowed_groups is provided and non-empty, store it for DLS filtering
|
||||||
|
if allowed_groups:
|
||||||
|
chunk_doc["allowed_groups"] = allowed_groups
|
||||||
|
|
||||||
# Mark as sample data if specified
|
# Mark as sample data if specified
|
||||||
if is_sample_data:
|
if is_sample_data:
|
||||||
chunk_doc["is_sample_data"] = "true"
|
chunk_doc["is_sample_data"] = "true"
|
||||||
|
|
@ -309,6 +317,7 @@ class DocumentFileProcessor(TaskProcessor):
|
||||||
owner_name: str = None,
|
owner_name: str = None,
|
||||||
owner_email: str = None,
|
owner_email: str = None,
|
||||||
is_sample_data: bool = False,
|
is_sample_data: bool = False,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
super().__init__(document_service)
|
super().__init__(document_service)
|
||||||
self.owner_user_id = owner_user_id
|
self.owner_user_id = owner_user_id
|
||||||
|
|
@ -316,6 +325,7 @@ class DocumentFileProcessor(TaskProcessor):
|
||||||
self.owner_name = owner_name
|
self.owner_name = owner_name
|
||||||
self.owner_email = owner_email
|
self.owner_email = owner_email
|
||||||
self.is_sample_data = is_sample_data
|
self.is_sample_data = is_sample_data
|
||||||
|
self.allowed_groups = allowed_groups
|
||||||
|
|
||||||
async def process_item(
|
async def process_item(
|
||||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||||
|
|
@ -351,6 +361,7 @@ class DocumentFileProcessor(TaskProcessor):
|
||||||
file_size=file_size,
|
file_size=file_size,
|
||||||
connector_type="local",
|
connector_type="local",
|
||||||
is_sample_data=self.is_sample_data,
|
is_sample_data=self.is_sample_data,
|
||||||
|
allowed_groups=self.allowed_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
file_task.status = TaskStatus.COMPLETED
|
file_task.status = TaskStatus.COMPLETED
|
||||||
|
|
@ -382,6 +393,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
owner_name: str = None,
|
owner_name: str = None,
|
||||||
owner_email: str = None,
|
owner_email: str = None,
|
||||||
document_service=None,
|
document_service=None,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
super().__init__(document_service=document_service)
|
super().__init__(document_service=document_service)
|
||||||
self.connector_service = connector_service
|
self.connector_service = connector_service
|
||||||
|
|
@ -391,6 +403,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
self.jwt_token = jwt_token
|
self.jwt_token = jwt_token
|
||||||
self.owner_name = owner_name
|
self.owner_name = owner_name
|
||||||
self.owner_email = owner_email
|
self.owner_email = owner_email
|
||||||
|
self.allowed_groups = allowed_groups
|
||||||
|
|
||||||
async def process_item(
|
async def process_item(
|
||||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||||
|
|
@ -445,6 +458,7 @@ class ConnectorFileProcessor(TaskProcessor):
|
||||||
owner_email=self.owner_email,
|
owner_email=self.owner_email,
|
||||||
file_size=len(document.content),
|
file_size=len(document.content),
|
||||||
connector_type=connection.connector_type,
|
connector_type=connection.connector_type,
|
||||||
|
allowed_groups=self.allowed_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add connector-specific metadata
|
# Add connector-specific metadata
|
||||||
|
|
@ -478,6 +492,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
|
||||||
jwt_token: str = None,
|
jwt_token: str = None,
|
||||||
owner_name: str = None,
|
owner_name: str = None,
|
||||||
owner_email: str = None,
|
owner_email: str = None,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.langflow_connector_service = langflow_connector_service
|
self.langflow_connector_service = langflow_connector_service
|
||||||
|
|
@ -487,6 +502,7 @@ class LangflowConnectorFileProcessor(TaskProcessor):
|
||||||
self.jwt_token = jwt_token
|
self.jwt_token = jwt_token
|
||||||
self.owner_name = owner_name
|
self.owner_name = owner_name
|
||||||
self.owner_email = owner_email
|
self.owner_email = owner_email
|
||||||
|
self.allowed_groups = allowed_groups
|
||||||
|
|
||||||
async def process_item(
|
async def process_item(
|
||||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||||
|
|
@ -580,6 +596,7 @@ class S3FileProcessor(TaskProcessor):
|
||||||
jwt_token: str = None,
|
jwt_token: str = None,
|
||||||
owner_name: str = None,
|
owner_name: str = None,
|
||||||
owner_email: str = None,
|
owner_email: str = None,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
|
|
@ -590,6 +607,7 @@ class S3FileProcessor(TaskProcessor):
|
||||||
self.jwt_token = jwt_token
|
self.jwt_token = jwt_token
|
||||||
self.owner_name = owner_name
|
self.owner_name = owner_name
|
||||||
self.owner_email = owner_email
|
self.owner_email = owner_email
|
||||||
|
self.allowed_groups = allowed_groups
|
||||||
|
|
||||||
async def process_item(
|
async def process_item(
|
||||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||||
|
|
@ -638,6 +656,7 @@ class S3FileProcessor(TaskProcessor):
|
||||||
owner_email=self.owner_email,
|
owner_email=self.owner_email,
|
||||||
file_size=file_size,
|
file_size=file_size,
|
||||||
connector_type="s3",
|
connector_type="s3",
|
||||||
|
allowed_groups=self.allowed_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
result["path"] = f"s3://{self.bucket}/{item}"
|
result["path"] = f"s3://{self.bucket}/{item}"
|
||||||
|
|
@ -669,6 +688,7 @@ class LangflowFileProcessor(TaskProcessor):
|
||||||
settings: dict = None,
|
settings: dict = None,
|
||||||
delete_after_ingest: bool = True,
|
delete_after_ingest: bool = True,
|
||||||
replace_duplicates: bool = False,
|
replace_duplicates: bool = False,
|
||||||
|
allowed_groups: list = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.langflow_file_service = langflow_file_service
|
self.langflow_file_service = langflow_file_service
|
||||||
|
|
@ -682,6 +702,7 @@ class LangflowFileProcessor(TaskProcessor):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
self.delete_after_ingest = delete_after_ingest
|
self.delete_after_ingest = delete_after_ingest
|
||||||
self.replace_duplicates = replace_duplicates
|
self.replace_duplicates = replace_duplicates
|
||||||
|
self.allowed_groups = allowed_groups
|
||||||
|
|
||||||
async def process_item(
|
async def process_item(
|
||||||
self, upload_task: UploadTask, item: str, file_task: FileTask
|
self, upload_task: UploadTask, item: str, file_task: FileTask
|
||||||
|
|
@ -765,6 +786,10 @@ class LangflowFileProcessor(TaskProcessor):
|
||||||
metadata_tweaks.append({"key": "owner_email", "value": self.owner_email})
|
metadata_tweaks.append({"key": "owner_email", "value": self.owner_email})
|
||||||
# Mark as local upload for connector_type
|
# Mark as local upload for connector_type
|
||||||
metadata_tweaks.append({"key": "connector_type", "value": "local"})
|
metadata_tweaks.append({"key": "connector_type", "value": "local"})
|
||||||
|
# RBAC: Add allowed_groups for access control
|
||||||
|
if self.allowed_groups:
|
||||||
|
# Store as comma-separated string for Langflow metadata
|
||||||
|
metadata_tweaks.append({"key": "allowed_groups", "value": ",".join(self.allowed_groups)})
|
||||||
|
|
||||||
if metadata_tweaks:
|
if metadata_tweaks:
|
||||||
# Initialize the OpenSearch component tweaks if not already present
|
# Initialize the OpenSearch component tweaks if not already present
|
||||||
|
|
|
||||||
|
|
@ -52,15 +52,19 @@ class APIKeyService:
|
||||||
user_email: str,
|
user_email: str,
|
||||||
name: str,
|
name: str,
|
||||||
jwt_token: str = None,
|
jwt_token: str = None,
|
||||||
|
roles: List[str] = None,
|
||||||
|
groups: List[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Create a new API key for a user.
|
Create a new API key for a user with optional RBAC restrictions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: The user's ID
|
user_id: The user's ID
|
||||||
user_email: The user's email
|
user_email: The user's email
|
||||||
name: A friendly name for the key
|
name: A friendly name for the key
|
||||||
jwt_token: JWT token for OpenSearch authentication
|
jwt_token: JWT token for OpenSearch authentication
|
||||||
|
roles: Optional list of roles to restrict this key to
|
||||||
|
groups: Optional list of groups this key can access
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with success status, key info, and the full key (only shown once)
|
Dict with success status, key info, and the full key (only shown once)
|
||||||
|
|
@ -74,6 +78,14 @@ class APIKeyService:
|
||||||
|
|
||||||
now = datetime.utcnow().isoformat()
|
now = datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
# Default roles if not specified
|
||||||
|
if roles is None:
|
||||||
|
roles = ["openrag_user"]
|
||||||
|
|
||||||
|
# Default groups to empty if not specified
|
||||||
|
if groups is None:
|
||||||
|
groups = []
|
||||||
|
|
||||||
# Create the document to store
|
# Create the document to store
|
||||||
key_doc = {
|
key_doc = {
|
||||||
"key_id": key_id,
|
"key_id": key_id,
|
||||||
|
|
@ -85,6 +97,9 @@ class APIKeyService:
|
||||||
"created_at": now,
|
"created_at": now,
|
||||||
"last_used_at": None,
|
"last_used_at": None,
|
||||||
"revoked": False,
|
"revoked": False,
|
||||||
|
# RBAC fields
|
||||||
|
"roles": roles,
|
||||||
|
"groups": groups,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get OpenSearch client
|
# Get OpenSearch client
|
||||||
|
|
@ -105,6 +120,8 @@ class APIKeyService:
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
key_id=key_id,
|
key_id=key_id,
|
||||||
key_prefix=key_prefix,
|
key_prefix=key_prefix,
|
||||||
|
roles=roles,
|
||||||
|
groups=groups,
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
|
|
@ -112,6 +129,8 @@ class APIKeyService:
|
||||||
"key_prefix": key_prefix,
|
"key_prefix": key_prefix,
|
||||||
"name": name,
|
"name": name,
|
||||||
"created_at": now,
|
"created_at": now,
|
||||||
|
"roles": roles,
|
||||||
|
"groups": groups,
|
||||||
"api_key": full_key, # Only returned once!
|
"api_key": full_key, # Only returned once!
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
|
|
@ -123,13 +142,13 @@ class APIKeyService:
|
||||||
|
|
||||||
async def validate_key(self, api_key: str) -> Optional[Dict[str, Any]]:
|
async def validate_key(self, api_key: str) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Validate an API key and return user info if valid.
|
Validate an API key and return user info with RBAC claims if valid.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: The full API key to validate
|
api_key: The full API key to validate
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with user info if valid, None if invalid
|
Dict with user info including roles and groups if valid, None if invalid
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Check key format
|
# Check key format
|
||||||
|
|
@ -181,11 +200,15 @@ class APIKeyService:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Don't fail validation if update fails
|
pass # Don't fail validation if update fails
|
||||||
|
|
||||||
|
# Return user info with RBAC claims
|
||||||
return {
|
return {
|
||||||
"key_id": key_doc["key_id"],
|
"key_id": key_doc["key_id"],
|
||||||
"user_id": key_doc["user_id"],
|
"user_id": key_doc["user_id"],
|
||||||
"user_email": key_doc["user_email"],
|
"user_email": key_doc["user_email"],
|
||||||
"name": key_doc["name"],
|
"name": key_doc["name"],
|
||||||
|
# RBAC fields - provide defaults for backward compatibility
|
||||||
|
"roles": key_doc.get("roles", ["openrag_user"]),
|
||||||
|
"groups": key_doc.get("groups", []),
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -225,6 +248,8 @@ class APIKeyService:
|
||||||
"created_at",
|
"created_at",
|
||||||
"last_used_at",
|
"last_used_at",
|
||||||
"revoked",
|
"revoked",
|
||||||
|
"roles",
|
||||||
|
"groups",
|
||||||
],
|
],
|
||||||
"size": 100,
|
"size": 100,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,8 @@ class ChatService:
|
||||||
previous_response_id: str = None,
|
previous_response_id: str = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
filter_id: str = None,
|
filter_id: str = None,
|
||||||
|
groups: list = None,
|
||||||
|
roles: list = None,
|
||||||
):
|
):
|
||||||
"""Handle chat requests using the patched OpenAI client"""
|
"""Handle chat requests using the patched OpenAI client"""
|
||||||
if not prompt:
|
if not prompt:
|
||||||
|
|
@ -23,7 +25,7 @@ class ChatService:
|
||||||
|
|
||||||
# Set authentication context for this request so tools can access it
|
# Set authentication context for this request so tools can access it
|
||||||
if user_id and jwt_token:
|
if user_id and jwt_token:
|
||||||
set_auth_context(user_id, jwt_token)
|
set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return async_chat_stream(
|
return async_chat_stream(
|
||||||
|
|
@ -54,6 +56,8 @@ class ChatService:
|
||||||
previous_response_id: str = None,
|
previous_response_id: str = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
filter_id: str = None,
|
filter_id: str = None,
|
||||||
|
groups: list = None,
|
||||||
|
roles: list = None,
|
||||||
):
|
):
|
||||||
"""Handle Langflow chat requests"""
|
"""Handle Langflow chat requests"""
|
||||||
if not prompt:
|
if not prompt:
|
||||||
|
|
@ -346,9 +350,9 @@ class ChatService:
|
||||||
previous_response_id=previous_response_id,
|
previous_response_id=previous_response_id,
|
||||||
)
|
)
|
||||||
else: # chat
|
else: # chat
|
||||||
# Set auth context for chat tools and provide user_id
|
# Set auth context for chat tools and provide user_id with RBAC
|
||||||
if user_id and jwt_token:
|
if user_id and jwt_token:
|
||||||
set_auth_context(user_id, jwt_token)
|
set_auth_context(user_id, jwt_token, groups=groups, roles=roles)
|
||||||
response_text, response_id = await async_chat(
|
response_text, response_id = await async_chat(
|
||||||
clients.patched_llm_client,
|
clients.patched_llm_client,
|
||||||
document_prompt,
|
document_prompt,
|
||||||
|
|
|
||||||
219
src/services/group_service.py
Normal file
219
src/services/group_service.py
Normal file
|
|
@ -0,0 +1,219 @@
|
||||||
|
"""
|
||||||
|
Group Service for managing user groups for RBAC.
|
||||||
|
"""
|
||||||
|
import secrets
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from config.settings import GROUPS_INDEX_NAME
|
||||||
|
from utils.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupService:
|
||||||
|
"""Service for managing user groups for RBAC."""
|
||||||
|
|
||||||
|
def __init__(self, session_manager=None):
|
||||||
|
self.session_manager = session_manager
|
||||||
|
|
||||||
|
async def _ensure_index_exists(self, opensearch_client) -> None:
|
||||||
|
"""Ensure the groups index exists."""
|
||||||
|
from config.settings import GROUPS_INDEX_BODY
|
||||||
|
|
||||||
|
try:
|
||||||
|
exists = await opensearch_client.indices.exists(index=GROUPS_INDEX_NAME)
|
||||||
|
if not exists:
|
||||||
|
await opensearch_client.indices.create(
|
||||||
|
index=GROUPS_INDEX_NAME,
|
||||||
|
body=GROUPS_INDEX_BODY,
|
||||||
|
)
|
||||||
|
logger.info(f"Created groups index: {GROUPS_INDEX_NAME}")
|
||||||
|
except Exception as e:
|
||||||
|
# Index might already exist from concurrent creation
|
||||||
|
if "resource_already_exists_exception" not in str(e):
|
||||||
|
logger.error(f"Failed to create groups index: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def create_group(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str = "",
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create a new user group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The group name (must be unique)
|
||||||
|
description: Optional description of the group
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with success status and group info
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get OpenSearch client
|
||||||
|
from config.settings import clients
|
||||||
|
|
||||||
|
opensearch_client = clients.opensearch
|
||||||
|
|
||||||
|
# Ensure index exists
|
||||||
|
await self._ensure_index_exists(opensearch_client)
|
||||||
|
|
||||||
|
# Check if group with this name already exists
|
||||||
|
search_body = {
|
||||||
|
"query": {"term": {"name": name}},
|
||||||
|
"size": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await opensearch_client.search(
|
||||||
|
index=GROUPS_INDEX_NAME,
|
||||||
|
body=search_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.get("hits", {}).get("hits", []):
|
||||||
|
return {"success": False, "error": f"Group '{name}' already exists"}
|
||||||
|
|
||||||
|
# Create a unique group_id
|
||||||
|
group_id = secrets.token_urlsafe(16)
|
||||||
|
now = datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
# Create the document to store
|
||||||
|
group_doc = {
|
||||||
|
"group_id": group_id,
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"created_at": now,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Index the group document
|
||||||
|
result = await opensearch_client.index(
|
||||||
|
index=GROUPS_INDEX_NAME,
|
||||||
|
id=group_id,
|
||||||
|
body=group_doc,
|
||||||
|
refresh="wait_for",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.get("result") in ("created", "updated"):
|
||||||
|
logger.info(f"Created group: {name} (id: {group_id})")
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"group_id": group_id,
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"created_at": now,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {"success": False, "error": "Failed to create group"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create group: {e}")
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
async def list_groups(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
List all user groups.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with list of groups
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get OpenSearch client
|
||||||
|
from config.settings import clients
|
||||||
|
|
||||||
|
opensearch_client = clients.opensearch
|
||||||
|
|
||||||
|
# Ensure index exists
|
||||||
|
await self._ensure_index_exists(opensearch_client)
|
||||||
|
|
||||||
|
# Search for all groups
|
||||||
|
search_body = {
|
||||||
|
"query": {"match_all": {}},
|
||||||
|
"sort": [{"name": {"order": "asc"}}],
|
||||||
|
"_source": ["group_id", "name", "description", "created_at"],
|
||||||
|
"size": 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await opensearch_client.search(
|
||||||
|
index=GROUPS_INDEX_NAME,
|
||||||
|
body=search_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
groups = []
|
||||||
|
for hit in result.get("hits", {}).get("hits", []):
|
||||||
|
groups.append(hit["_source"])
|
||||||
|
|
||||||
|
return {"success": True, "groups": groups}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to list groups: {e}")
|
||||||
|
return {"success": False, "error": str(e), "groups": []}
|
||||||
|
|
||||||
|
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get a group by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: The group ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Group info if found, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get OpenSearch client
|
||||||
|
from config.settings import clients
|
||||||
|
|
||||||
|
opensearch_client = clients.opensearch
|
||||||
|
|
||||||
|
doc = await opensearch_client.get(
|
||||||
|
index=GROUPS_INDEX_NAME,
|
||||||
|
id=group_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return doc["_source"]
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def delete_group(self, group_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Delete a user group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group_id: The group ID to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with success status
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get OpenSearch client
|
||||||
|
from config.settings import clients
|
||||||
|
|
||||||
|
opensearch_client = clients.opensearch
|
||||||
|
|
||||||
|
# Verify the group exists
|
||||||
|
try:
|
||||||
|
doc = await opensearch_client.get(
|
||||||
|
index=GROUPS_INDEX_NAME,
|
||||||
|
id=group_id,
|
||||||
|
)
|
||||||
|
group_name = doc["_source"].get("name", "unknown")
|
||||||
|
except Exception:
|
||||||
|
return {"success": False, "error": "Group not found"}
|
||||||
|
|
||||||
|
# Delete the group
|
||||||
|
result = await opensearch_client.delete(
|
||||||
|
index=GROUPS_INDEX_NAME,
|
||||||
|
id=group_id,
|
||||||
|
refresh="wait_for",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.get("result") == "deleted":
|
||||||
|
logger.info(f"Deleted group: {group_name} (id: {group_id})")
|
||||||
|
return {"success": True}
|
||||||
|
else:
|
||||||
|
return {"success": False, "error": "Failed to delete group"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete group: {e}")
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
@ -2,7 +2,7 @@ import copy
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from agentd.tool_decorator import tool
|
from agentd.tool_decorator import tool
|
||||||
from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model, WATSONX_EMBEDDING_DIMENSIONS
|
from config.settings import EMBED_MODEL, clients, INDEX_NAME, get_embedding_model, WATSONX_EMBEDDING_DIMENSIONS
|
||||||
from auth_context import get_auth_context
|
from auth_context import get_auth_context, get_current_user_groups
|
||||||
from utils.logging_config import get_logger
|
from utils.logging_config import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
@ -259,8 +259,24 @@ class SearchService:
|
||||||
# Build query body
|
# Build query body
|
||||||
if is_wildcard_match_all:
|
if is_wildcard_match_all:
|
||||||
# Match all documents; still allow filters to narrow scope
|
# Match all documents; still allow filters to narrow scope
|
||||||
if filter_clauses:
|
# Also add RBAC group filter for wildcard queries
|
||||||
query_block = {"bool": {"filter": filter_clauses}}
|
wildcard_filters = list(filter_clauses) # Copy existing filters
|
||||||
|
|
||||||
|
user_groups = get_current_user_groups()
|
||||||
|
if user_groups:
|
||||||
|
groups_access_filter = {
|
||||||
|
"bool": {
|
||||||
|
"should": [
|
||||||
|
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
|
||||||
|
{"terms": {"allowed_groups": user_groups}},
|
||||||
|
],
|
||||||
|
"minimum_should_match": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wildcard_filters.append(groups_access_filter)
|
||||||
|
|
||||||
|
if wildcard_filters:
|
||||||
|
query_block = {"bool": {"filter": wildcard_filters}}
|
||||||
else:
|
else:
|
||||||
query_block = {"match_all": {}}
|
query_block = {"match_all": {}}
|
||||||
else:
|
else:
|
||||||
|
|
@ -292,6 +308,29 @@ class SearchService:
|
||||||
# Add exists filter to existing filters
|
# Add exists filter to existing filters
|
||||||
all_filters = [*filter_clauses, exists_any_embedding]
|
all_filters = [*filter_clauses, exists_any_embedding]
|
||||||
|
|
||||||
|
# RBAC: Add group-based access control filter (fallback if DLS isn't configured)
|
||||||
|
# Documents are accessible if:
|
||||||
|
# 1. No allowed_groups field exists (backward compatibility, open access)
|
||||||
|
# 2. User's groups match any of the document's allowed_groups
|
||||||
|
user_groups = get_current_user_groups()
|
||||||
|
if user_groups:
|
||||||
|
groups_access_filter = {
|
||||||
|
"bool": {
|
||||||
|
"should": [
|
||||||
|
# Document has no allowed_groups restriction
|
||||||
|
{"bool": {"must_not": {"exists": {"field": "allowed_groups"}}}},
|
||||||
|
# User's groups match document's allowed_groups
|
||||||
|
{"terms": {"allowed_groups": user_groups}},
|
||||||
|
],
|
||||||
|
"minimum_should_match": 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
all_filters.append(groups_access_filter)
|
||||||
|
logger.debug(
|
||||||
|
"Added RBAC group filter",
|
||||||
|
user_groups=user_groups,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Building hybrid query with filters",
|
"Building hybrid query with filters",
|
||||||
user_filters_count=len(filter_clauses),
|
user_filters_count=len(filter_clauses),
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,8 @@ import json
|
||||||
import jwt
|
import jwt
|
||||||
import httpx
|
import httpx
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Optional, Any
|
from typing import Dict, Optional, Any, List
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, field
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
import os
|
import os
|
||||||
from utils.logging_config import get_logger
|
from utils.logging_config import get_logger
|
||||||
|
|
@ -24,12 +24,19 @@ class User:
|
||||||
provider: str = "google"
|
provider: str = "google"
|
||||||
created_at: datetime = None
|
created_at: datetime = None
|
||||||
last_login: datetime = None
|
last_login: datetime = None
|
||||||
|
# RBAC fields
|
||||||
|
roles: List[str] = field(default_factory=lambda: ["openrag_user"])
|
||||||
|
groups: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.created_at is None:
|
if self.created_at is None:
|
||||||
self.created_at = datetime.now()
|
self.created_at = datetime.now()
|
||||||
if self.last_login is None:
|
if self.last_login is None:
|
||||||
self.last_login = datetime.now()
|
self.last_login = datetime.now()
|
||||||
|
if self.roles is None:
|
||||||
|
self.roles = ["openrag_user"]
|
||||||
|
if self.groups is None:
|
||||||
|
self.groups = []
|
||||||
|
|
||||||
class AnonymousUser(User):
|
class AnonymousUser(User):
|
||||||
"""Anonymous user"""
|
"""Anonymous user"""
|
||||||
|
|
@ -136,11 +143,31 @@ class SessionManager:
|
||||||
# Create JWT token using the shared method
|
# Create JWT token using the shared method
|
||||||
return self.create_jwt_token(user)
|
return self.create_jwt_token(user)
|
||||||
|
|
||||||
def create_jwt_token(self, user: User) -> str:
|
def create_jwt_token(
|
||||||
"""Create JWT token for an existing user"""
|
self,
|
||||||
|
user: User,
|
||||||
|
roles: Optional[List[str]] = None,
|
||||||
|
groups: Optional[List[str]] = None,
|
||||||
|
expiration_days: int = 7,
|
||||||
|
) -> str:
|
||||||
|
"""Create JWT token for an existing user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: The User object to create a token for
|
||||||
|
roles: Optional roles override (for restricted API key tokens)
|
||||||
|
groups: Optional groups override (for restricted API key tokens)
|
||||||
|
expiration_days: Token expiration in days (default 7)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Encoded JWT token string
|
||||||
|
"""
|
||||||
# Use OpenSearch-compatible issuer for OIDC validation
|
# Use OpenSearch-compatible issuer for OIDC validation
|
||||||
oidc_issuer = "http://openrag-backend:8000"
|
oidc_issuer = "http://openrag-backend:8000"
|
||||||
|
|
||||||
|
# Use provided roles/groups or fall back to user's defaults
|
||||||
|
effective_roles = roles if roles is not None else user.roles
|
||||||
|
effective_groups = groups if groups is not None else user.groups
|
||||||
|
|
||||||
# Create JWT token with OIDC-compliant claims
|
# Create JWT token with OIDC-compliant claims
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
token_payload = {
|
token_payload = {
|
||||||
|
|
@ -148,7 +175,7 @@ class SessionManager:
|
||||||
"iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC
|
"iss": oidc_issuer, # Fixed issuer for OpenSearch OIDC
|
||||||
"sub": user.user_id, # Subject (user ID)
|
"sub": user.user_id, # Subject (user ID)
|
||||||
"aud": ["opensearch", "openrag"], # Audience
|
"aud": ["opensearch", "openrag"], # Audience
|
||||||
"exp": now + timedelta(days=7), # Expiration
|
"exp": now + timedelta(days=expiration_days), # Expiration
|
||||||
"iat": now, # Issued at
|
"iat": now, # Issued at
|
||||||
"auth_time": int(now.timestamp()), # Authentication time
|
"auth_time": int(now.timestamp()), # Authentication time
|
||||||
# Custom claims
|
# Custom claims
|
||||||
|
|
@ -157,7 +184,9 @@ class SessionManager:
|
||||||
"name": user.name,
|
"name": user.name,
|
||||||
"preferred_username": user.email,
|
"preferred_username": user.email,
|
||||||
"email_verified": True,
|
"email_verified": True,
|
||||||
"roles": ["openrag_user"], # Backend role for OpenSearch
|
# RBAC claims
|
||||||
|
"roles": effective_roles, # Backend roles for OpenSearch/tools
|
||||||
|
"groups": effective_groups, # Group-based access control
|
||||||
}
|
}
|
||||||
|
|
||||||
token = jwt.encode(token_payload, self.private_key, algorithm="RS256")
|
token = jwt.encode(token_payload, self.private_key, algorithm="RS256")
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue