Rate assistant responses (#39224)

* rate assistant responses

* test

* always show
This commit is contained in:
Saxon Fletcher
2025-10-07 08:46:30 +10:00
committed by GitHub
parent 31b6368049
commit d5bf4ef13b
10 changed files with 592 additions and 6 deletions

View File

@@ -10,6 +10,7 @@ import { LOCAL_STORAGE_KEYS, useFlag } from 'common'
import { useParams, useSearchParamsShallow } from 'common/hooks'
import { Markdown } from 'components/interfaces/Markdown'
import { useCheckOpenAIKeyQuery } from 'data/ai/check-api-key-query'
import { useRateMessageMutation } from 'data/ai/rate-message-mutation'
import { constructHeaders } from 'data/fetchers'
import { useTablesQuery } from 'data/tables/tables-query'
import { useSendEventMutation } from 'data/telemetry/send-event-mutation'
@@ -83,10 +84,13 @@ export const AIAssistant = ({ className }: AIAssistantProps) => {
const [value, setValue] = useState<string>(snap.initialInput || '')
const [editingMessageId, setEditingMessageId] = useState<string | null>(null)
const [isResubmitting, setIsResubmitting] = useState(false)
const [messageRatings, setMessageRatings] = useState<Record<string, 'positive' | 'negative'>>({})
const { data: check, isSuccess } = useCheckOpenAIKeyQuery()
const isApiKeySet = IS_PLATFORM || !!check?.hasKey
const { mutateAsync: rateMessage } = useRateMessageMutation()
const isInSQLEditor = router.pathname.includes('/sql/[id]')
const snippet = snippets[entityId ?? '']
const snippetContent = snippet?.snippet?.content?.sql
@@ -260,6 +264,47 @@ export const AIAssistant = ({ className }: AIAssistantProps) => {
setValue('')
}, [setValue])
// Submits a thumbs up/down rating for a single assistant message.
// The rating is shown immediately and removed again if the request fails;
// a telemetry event is sent only after the rating request resolves.
const handleRateMessage = useCallback(
  async (messageId: string, rating: 'positive' | 'negative', reason?: string) => {
    const projectRef = project?.ref
    const orgSlug = selectedOrganization?.slug
    // Both identifiers are required by the request and the telemetry groups
    if (!projectRef || !orgSlug) return

    // Show the rating right away; undone in the catch branch below
    setMessageRatings((current) => ({ ...current, [messageId]: rating }))

    try {
      const { category } = await rateMessage({
        rating,
        messages: chatMessages,
        messageId,
        projectRef,
        orgSlug,
        reason,
      })

      sendEvent({
        action: 'assistant_message_rating_submitted',
        properties: {
          rating,
          category,
          // Only attach the reason when the user actually typed one
          ...(reason && { reason }),
        },
        groups: {
          project: projectRef,
          organization: orgSlug,
        },
      })
    } catch (error) {
      console.error('Failed to rate message:', error)
      // Drop the optimistic entry so the buttons become usable again
      setMessageRatings((current) => {
        const next = { ...current }
        delete next[messageId]
        return next
      })
    }
  },
  [chatMessages, project?.ref, selectedOrganization?.slug, rateMessage, sendEvent]
)
const renderedMessages = useMemo(
() =>
chatMessages.map((message, index) => {
@@ -283,6 +328,8 @@ export const AIAssistant = ({ className }: AIAssistantProps) => {
isBeingEdited={isBeingEdited}
onCancelEdit={cancelEdit}
isLastMessage={isLastMessage}
onRate={handleRateMessage}
rating={messageRatings[message.id] ?? null}
/>
)
}),
@@ -294,6 +341,8 @@ export const AIAssistant = ({ className }: AIAssistantProps) => {
editingMessageId,
chatStatus,
addToolResult,
handleRateMessage,
messageRatings,
]
)

View File

@@ -1,13 +1,33 @@
import { Pencil, Trash2 } from 'lucide-react'
import { type PropsWithChildren } from 'react'
import { Pencil, ThumbsDown, ThumbsUp, Trash2 } from 'lucide-react'
import { type PropsWithChildren, useState, useEffect } from 'react'
import { zodResolver } from '@hookform/resolvers/zod'
import { useForm } from 'react-hook-form'
import * as z from 'zod'
import { ButtonTooltip } from '../ButtonTooltip'
import {
cn,
Button,
Popover_Shadcn_,
PopoverTrigger_Shadcn_,
PopoverContent_Shadcn_,
Form_Shadcn_,
FormField_Shadcn_,
FormControl_Shadcn_,
TextArea_Shadcn_,
} from 'ui'
import { FormItemLayout } from 'ui-patterns/form/FormItemLayout/FormItemLayout'
export function MessageActions({ children }: PropsWithChildren<{}>) {
export function MessageActions({
children,
alwaysShow = false,
}: PropsWithChildren<{ alwaysShow?: boolean }>) {
return (
<div className="flex items-center gap-4 mt-2 mb-1">
<span className="h-0.5 w-5 bg-muted" />
<div className="opacity-0 group-hover:opacity-100 transition-opacity">{children}</div>
<div className={cn('group-hover:opacity-100 transition-opacity', !alwaysShow && 'opacity-0')}>
{children}
</div>
</div>
)
}
@@ -44,3 +64,147 @@ function MessageActionsDelete({ onClick }: { onClick: () => void }) {
)
}
MessageActions.Delete = MessageActionsDelete
/**
 * Thumbs-up action button for an assistant message.
 *
 * @param onClick - Invoked when the button is clicked.
 * @param isActive - When true the icon uses the brand colour classes.
 * @param disabled - When true the button is disabled, dimmed and ignores pointer events.
 */
function MessageActionsThumbsUp({
  onClick,
  isActive,
  disabled,
}: {
  onClick: () => void
  isActive?: boolean
  disabled?: boolean
}) {
  const iconColour = isActive
    ? 'text-brand hover:text-brand-700'
    : 'text-foreground-light hover:text-foreground'
  const thumbIcon = <ThumbsUp size={14} strokeWidth={1.5} className={cn(iconColour)} />
  const buttonClasses = cn(
    'p-1 rounded transition-colors',
    disabled && 'opacity-50 pointer-events-none'
  )

  return (
    <Button
      type="text"
      title="Good response"
      aria-label="Good response"
      disabled={disabled}
      icon={thumbIcon}
      className={buttonClasses}
      onClick={onClick}
    />
  )
}
MessageActions.ThumbsUp = MessageActionsThumbsUp
// Validation schema for the thumbs-down feedback form; `reason` is an optional string.
const feedbackSchema = z.object({
reason: z.string().optional(),
})
// Form value shape inferred from `feedbackSchema`.
type FeedbackFormValues = z.infer<typeof feedbackSchema>
/**
 * Thumbs-down action: a popover trigger button plus an optional free-text feedback form.
 *
 * `onClick` is invoked with the entered reason when the form is submitted, or with no
 * argument when the popover is closed without a successful submission.
 *
 * @param onClick - Callback receiving the optional feedback reason.
 * @param isActive - When true the icon uses the warning colour classes.
 * @param disabled - When true the trigger is disabled and open/close requests are ignored.
 */
function MessageActionsThumbsDown({
onClick,
isActive,
disabled,
}: {
onClick: (reason?: string) => void
isActive?: boolean
disabled?: boolean
}) {
// Controlled open state of the popover
const [open, setOpen] = useState(false)
const form = useForm<FeedbackFormValues>({
resolver: zodResolver(feedbackSchema),
defaultValues: { reason: '' },
mode: 'onSubmit',
})
// Handles open/close requests coming from the popover itself.
const handleOpenChange = (newOpen: boolean) => {
// While disabled, neither opening nor closing is honoured
if (disabled) return
// When popover closes, submit the rating if not already submitted
if (!newOpen && open && !form.formState.isSubmitSuccessful) {
onClick()
}
setOpen(newOpen)
// Closing discards whatever was typed and clears the form state
if (!newOpen) {
form.reset()
}
}
// Form submit handler: an empty reason is passed on as `undefined`.
const onSubmit = (values: FeedbackFormValues) => {
onClick(values.reason || undefined)
}
// Auto-close popover after showing thank you message
// NOTE(review): this closes via `setOpen` directly, so the `form.reset()` in
// `handleOpenChange` does not run on auto-close — confirm the form state is not
// expected to be cleared here (e.g. if the rating is later rolled back).
useEffect(() => {
if (form.formState.isSubmitSuccessful) {
const timer = setTimeout(() => {
setOpen(false)
}, 2000)
// Cancel the pending close if the effect re-runs or the component unmounts
return () => clearTimeout(timer)
}
}, [form.formState.isSubmitSuccessful])
return (
<Popover_Shadcn_ open={open} onOpenChange={handleOpenChange}>
<PopoverTrigger_Shadcn_ asChild>
<Button
type="text"
disabled={disabled}
onClick={() => !disabled && setOpen(true)}
className={cn(
'p-1 rounded transition-colors',
disabled && 'opacity-50 pointer-events-none'
)}
title="Bad response"
aria-label="Bad response"
>
<ThumbsDown
size={14}
strokeWidth={1.5}
className={cn(
isActive
? 'text-warning hover:text-warning-700'
: 'text-foreground-light hover:text-foreground'
)}
/>
</Button>
</PopoverTrigger_Shadcn_>
<PopoverContent_Shadcn_ portal className="w-80" align="start">
{form.formState.isSubmitSuccessful ? (
<p className="text-sm">We appreciate your feedback!</p>
) : (
<Form_Shadcn_ {...form}>
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-3">
<FormField_Shadcn_
control={form.control}
name="reason"
render={({ field }) => (
<FormItemLayout label="What went wrong?" labelOptional="optional">
<FormControl_Shadcn_>
<TextArea_Shadcn_
placeholder="Describe why the response was not helpful..."
autoComplete="off"
rows={4}
autoFocus
{...field}
/>
</FormControl_Shadcn_>
</FormItemLayout>
)}
/>
<div className="flex justify-end">
<Button type="primary" htmlType="submit" size="tiny">
Submit feedback
</Button>
</div>
</form>
</Form_Shadcn_>
)}
</PopoverContent_Shadcn_>
</Popover_Shadcn_>
)
}
MessageActions.ThumbsDown = MessageActionsThumbsDown

View File

@@ -18,6 +18,7 @@ export interface MessageInfo {
isLastMessage?: boolean
state: 'idle' | 'editing' | 'predecessor-editing'
rating?: 'positive' | 'negative' | null
}
export interface MessageActions {
@@ -26,6 +27,7 @@ export interface MessageActions {
onDelete: (id: string) => void
onEdit: (id: string) => void
onCancelEdit: () => void
onRate?: (id: string, rating: 'positive' | 'negative', reason?: string) => void
}
const MessageInfoContext = createContext<MessageInfo | null>(null)

View File

@@ -10,8 +10,12 @@ import { MessageDisplay } from './Message.Display'
import { MessageProvider, useMessageActionsContext, useMessageInfoContext } from './Message.Context'
function AssistantMessage({ message }: { message: VercelMessage }) {
const { variant, state } = useMessageInfoContext()
const { onCancelEdit } = useMessageActionsContext()
const { id, variant, state, isLastMessage, readOnly, rating, isLoading } = useMessageInfoContext()
const { onCancelEdit, onRate } = useMessageActionsContext()
const handleRate = (newRating: 'positive' | 'negative', reason?: string) => {
onRate?.(id, newRating, reason)
}
return (
<MessageDisplay.Container
@@ -24,6 +28,20 @@ function AssistantMessage({ message }: { message: VercelMessage }) {
<MessageDisplay.MainArea>
<MessageDisplay.Content message={message} />
</MessageDisplay.MainArea>
{!readOnly && isLastMessage && onRate && !isLoading && (
<MessageActions alwaysShow>
<MessageActions.ThumbsUp
onClick={() => handleRate('positive')}
isActive={rating === 'positive'}
disabled={!!rating}
/>
<MessageActions.ThumbsDown
onClick={(reason) => handleRate('negative', reason)}
isActive={rating === 'negative'}
disabled={!!rating}
/>
</MessageActions>
)}
</MessageDisplay.Container>
)
}
@@ -81,6 +99,8 @@ interface MessageProps {
isBeingEdited: boolean
onCancelEdit: () => void
isLastMessage?: boolean
onRate?: (id: string, rating: 'positive' | 'negative', reason?: string) => void
rating?: 'positive' | 'negative' | null
}
export function Message(props: MessageProps) {
@@ -99,6 +119,7 @@ export function Message(props: MessageProps) {
? 'predecessor-editing'
: 'idle',
isLastMessage: props.isLastMessage,
rating: props.rating,
} satisfies MessageInfo
const messageActions = {
@@ -106,6 +127,7 @@ export function Message(props: MessageProps) {
onDelete: props.onDelete,
onEdit: props.onEdit,
onCancelEdit: props.onCancelEdit,
onRate: props.onRate,
}
return (

View File

@@ -96,3 +96,18 @@ export const deployEdgeFunctionInputSchema = z
export const deployEdgeFunctionOutputSchema = z
.object({ success: z.boolean().optional() })
.passthrough()
// Conversation categories a rated message can be classified into.
const RATE_MESSAGE_CATEGORIES = [
  'sql_generation',
  'schema_design',
  'rls_policies',
  'edge_functions',
  'database_optimization',
  'debugging',
  'general_help',
  'other',
] as const

/** Response payload of the rate-message endpoint: a single `category` value. */
export const rateMessageResponseSchema = z.object({
  category: z.enum(RATE_MESSAGE_CATEGORIES),
})

/** Static type inferred from `rateMessageResponseSchema`. */
export type RateMessageResponse = z.infer<typeof rateMessageResponseSchema>

View File

@@ -0,0 +1,74 @@
import { useMutation, UseMutationOptions } from '@tanstack/react-query'
import { UIMessage } from '@ai-sdk/react'
import { constructHeaders, fetchHandler } from 'data/fetchers'
import { BASE_PATH } from 'lib/constants'
import { ResponseError } from 'types'
import type { RateMessageResponse } from 'components/ui/AIAssistantPanel/Message.utils'
// Variables accepted by `rateMessage`; every field is serialized into the POST body.
export type RateMessageVariables = {
// Thumbs up or thumbs down
rating: 'positive' | 'negative'
// Conversation messages sent along with the rating
messages: UIMessage[]
// Id of the message being rated
messageId: string
projectRef: string
orgSlug?: string
// Optional free-text explanation supplied by the user
reason?: string
}
/**
 * Pulls a human-readable message out of an error response body.
 * Understands `{ message }`, `{ error: string }` and `{ error: { message } }`.
 */
function extractRateMessageError(body: unknown): string | undefined {
  if (body === null || typeof body !== 'object') return undefined

  const { message, error } = body as { message?: unknown; error?: unknown }
  if (typeof message === 'string') return message
  if (typeof error === 'string') return error
  if (error !== null && typeof error === 'object') {
    const nested = (error as { message?: unknown }).message
    if (typeof nested === 'string') return nested
  }
  return undefined
}

/**
 * Sends a message rating to `/api/ai/feedback/rate` and returns the parsed response.
 *
 * @param variables - Rating, conversation messages, message id, project ref and the
 *   optional org slug / reason; all of them are forwarded in the JSON body.
 * @returns The response body, typed as `RateMessageResponse`.
 * @throws ResponseError when the response is not ok, or when a successful response
 *   carries no parsable JSON body.
 */
export async function rateMessage({
  rating,
  messages,
  messageId,
  projectRef,
  orgSlug,
  reason,
}: RateMessageVariables) {
  const url = `${BASE_PATH}/api/ai/feedback/rate`
  const headers = await constructHeaders({ 'Content-Type': 'application/json' })

  const response = await fetchHandler(url, {
    headers,
    method: 'POST',
    body: JSON.stringify({ rating, messages, messageId, projectRef, orgSlug, reason }),
  })

  // A non-JSON body is tolerated here and handled explicitly below
  let body: unknown
  try {
    body = await response.json()
  } catch {
    body = undefined
  }

  if (!response.ok) {
    // The endpoint reports failures under `error`, not `message`, so check both
    const message = extractRateMessageError(body) ?? 'Failed to rate message'
    throw new ResponseError(message, response.status)
  }

  if (body === undefined || body === null) {
    // Without this the caller would receive `undefined` typed as a valid response
    throw new ResponseError('Received an invalid response while rating message', response.status)
  }

  return body as RateMessageResponse
}
type RateMessageData = Awaited<ReturnType<typeof rateMessage>>

type RateMessageMutationOptions = Omit<
  UseMutationOptions<RateMessageData, ResponseError, RateMessageVariables>,
  'mutationFn'
>

/**
 * Mutation hook around `rateMessage`.
 *
 * A caller-supplied `onSuccess` is awaited. A caller-supplied `onError` replaces the
 * default behaviour of logging the failure to the console. Remaining options are spread
 * last and passed straight to `useMutation`.
 */
export const useRateMessageMutation = (mutationOptions: RateMessageMutationOptions = {}) => {
  const { onSuccess, onError, ...options } = mutationOptions

  return useMutation<RateMessageData, ResponseError, RateMessageVariables>(
    (variables) => rateMessage(variables),
    {
      onSuccess: async (data, variables, context) => {
        await onSuccess?.(data, variables, context)
      },
      onError: async (error, variables, context) => {
        if (onError !== undefined) {
          onError(error, variables, context)
          return
        }
        console.error(`Failed to rate message: ${error.message}`)
      },
      ...options,
    }
  )
}

View File

@@ -0,0 +1,74 @@
import { expect, test, vi } from 'vitest'
// End of third-party imports
import rate from '../../pages/api/ai/feedback/rate'
import { sanitizeMessagePart } from '../ai/tools/tool-sanitizer'
// All module mocks are declared at module top level. Vitest hoists `vi.mock` calls
// above the imports regardless of where they are written, so declaring them here
// reflects the order in which they actually run.
vi.mock('../ai/tools/tool-sanitizer', () => ({
  sanitizeMessagePart: vi.fn((part) => part),
}))

vi.mock('lib/ai/org-ai-details', () => ({
  getOrgAIDetails: vi.fn().mockResolvedValue({
    aiOptInLevel: 'schema_and_log_and_data',
    isLimited: false,
  }),
}))

vi.mock('lib/ai/model', () => ({
  getModel: vi.fn().mockResolvedValue({
    model: {},
    error: null,
  }),
}))

vi.mock('ai', () => ({
  generateObject: vi.fn().mockResolvedValue({
    object: {
      category: 'sql_generation',
    },
  }),
}))

vi.mock('components/ui/AIAssistantPanel/Message.utils', () => ({
  rateMessageResponseSchema: {},
}))

// Verifies that the rate endpoint passes assistant message parts through the sanitizer.
test('rate calls the tool sanitizer', async () => {
  const mockReq = {
    method: 'POST',
    headers: {
      authorization: 'Bearer test-token',
    },
    body: {
      rating: 'negative',
      messages: [
        {
          role: 'assistant',
          parts: [
            {
              type: 'tool-execute_sql',
              state: 'output-available',
              output: 'test output',
            },
          ],
        },
      ],
      messageId: 'test-message-id',
      projectRef: 'test-project',
      orgSlug: 'test-org',
      reason: 'The response was not helpful',
    },
    on: vi.fn(),
  }

  // Each method returns the mock itself so calls can be chained
  const mockRes = {
    status: vi.fn(() => mockRes),
    json: vi.fn(() => mockRes),
    setHeader: vi.fn(() => mockRes),
  }

  await rate(mockReq as any, mockRes as any)

  expect(sanitizeMessagePart).toHaveBeenCalled()
})

View File

@@ -8,6 +8,7 @@ export const config = {
// [Joshen] Return 404 for all next.js API endpoints EXCEPT the ones we use in hosted:
const HOSTED_SUPPORTED_API_URLS = [
'/ai/sql/generate-v4',
'/ai/feedback/rate',
'/ai/code/complete',
'/ai/sql/cron-v2',
'/ai/sql/title-v2',

View File

@@ -0,0 +1,155 @@
import { generateObject } from 'ai'
import { NextApiRequest, NextApiResponse } from 'next'
import { z } from 'zod'
import { IS_PLATFORM } from 'common'
import type { AiOptInLevel } from 'hooks/misc/useOrgOptedIntoAi'
import { getModel } from 'lib/ai/model'
import { getOrgAIDetails } from 'lib/ai/org-ai-details'
import { sanitizeMessagePart } from 'lib/ai/tools/tool-sanitizer'
import apiWrapper from 'lib/api/apiWrapper'
import { rateMessageResponseSchema } from 'components/ui/AIAssistantPanel/Message.utils'
// NOTE(review): presumably the route's maximum execution time in seconds — confirm
// against the hosting platform's route configuration.
export const maxDuration = 30

/**
 * Method dispatcher for the feedback rating route.
 * POST is delegated to `handlePost`; any other method gets a 405 with an `Allow` header.
 */
async function handler(req: NextApiRequest, res: NextApiResponse) {
  const { method } = req

  if (method === 'POST') {
    return handlePost(req, res)
  }

  res.setHeader('Allow', ['POST'])
  res.status(405).json({ data: null, error: { message: `Method ${method} Not Allowed` } })
}
// Shape of the JSON body accepted by `handlePost`.
// Note that the entries of `messages` are not validated individually (`z.any()`).
const requestBodySchema = z.object({
rating: z.enum(['positive', 'negative']),
messages: z.array(z.any()),
messageId: z.string(),
projectRef: z.string(),
orgSlug: z.string().optional(),
reason: z.string().optional(),
})
/**
 * Prepares one conversation message for inclusion in the classification prompt.
 *
 * Non-assistant messages are returned untouched. For assistant messages the legacy
 * `results` field is removed and every entry of `parts` is passed through
 * `sanitizeMessagePart` with the given opt-in level. A `parts` value that is not an
 * array cannot be sanitized, so it is dropped instead of being forwarded.
 */
function sanitizeFeedbackMessage(msg: any, aiOptInLevel: AiOptInLevel): any {
  if (!msg || typeof msg !== 'object' || msg.role !== 'assistant') return msg

  const cleaned = { ...msg }
  delete cleaned.results

  if (Array.isArray(msg.parts)) {
    cleaned.parts = msg.parts.map((part: any) => sanitizeMessagePart(part, aiOptInLevel))
  } else if (msg.parts) {
    delete cleaned.parts
  }

  return cleaned
}

/**
 * Handles POST requests to the feedback rating route: validates the body, sanitizes the
 * last seven conversation messages and asks the model to classify the conversation.
 *
 * Responses:
 * - 200 `{ category }` on success
 * - 400 for an unparsable or invalid body, a failed organization lookup, or a
 *   conversation that exceeds the model's context length
 * - 401 when running on platform without an authorization token
 * - 500 for model or other unexpected errors
 */
export async function handlePost(req: NextApiRequest, res: NextApiResponse) {
  const authorization = req.headers.authorization
  const accessToken = authorization?.replace('Bearer ', '')

  if (IS_PLATFORM && !accessToken) {
    return res.status(401).json({ error: 'Authorization token is required' })
  }

  // A string body that is not valid JSON is a client error, not a server crash
  let body: unknown
  try {
    body = typeof req.body === 'string' ? JSON.parse(req.body) : req.body
  } catch {
    return res.status(400).json({ error: 'Invalid request body' })
  }

  const parsed = requestBodySchema.safeParse(body)
  if (!parsed.success) {
    return res.status(400).json({ error: 'Invalid request body', issues: parsed.error.issues })
  }

  const { rating, messages: rawMessages, projectRef, orgSlug, reason } = parsed.data

  // Off platform the level is fixed to 'schema'; on platform it stays 'disabled'
  // unless the organization lookup below succeeds
  let aiOptInLevel: AiOptInLevel = IS_PLATFORM ? 'disabled' : 'schema'

  if (IS_PLATFORM && orgSlug && authorization && projectRef) {
    try {
      // Get organizations and compute opt in level server-side
      const { aiOptInLevel: orgAIOptInLevel } = await getOrgAIDetails({
        orgSlug,
        authorization,
        projectRef,
      })

      aiOptInLevel = orgAIOptInLevel
    } catch (error) {
      return res.status(400).json({
        error: 'There was an error fetching your organization details',
      })
    }
  }

  // Only the last 7 messages are used.
  // Tool outputs are filtered based on opt-in level using sanitizeMessagePart.
  const messages = rawMessages
    .slice(-7)
    .map((msg: any) => sanitizeFeedbackMessage(msg, aiOptInLevel))

  try {
    const { model, error: modelError } = await getModel({
      provider: 'openai',
      isLimited: true,
      routingKey: 'feedback',
    })

    if (modelError) {
      return res.status(500).json({ error: modelError.message })
    }

    const { object } = await generateObject({
      model,
      schema: rateMessageResponseSchema,
      prompt: `
Your job is to look at a Supabase Assistant conversation, which the user has given feedback on, and classify it.
The user gave this feedback: ${rating === 'positive' ? 'THUMBS UP (positive)' : 'THUMBS DOWN (negative)'}
${reason ? `\nUser's reason: ${reason}` : ''}
Raw conversation:
${JSON.stringify(messages)}
Instructions:
1. Classify the conversation into ONE of these categories:
- sql_generation: Generating SQL queries, DML statements
- schema_design: Creating tables, columns, relationships
- rls_policies: Row Level Security policies
- edge_functions: Edge Functions or serverless functions
- database_optimization: Performance, indexes, optimization
- debugging: Helping debug errors or issues
- general_help: General questions about Supabase features
- other: Anything else
`,
    })

    return res.json({
      category: object.category,
    })
  } catch (error) {
    if (error instanceof Error) {
      console.error(`Classifying feedback failed:`, error)

      // Check for context length error
      if (error.message.includes('context_length') || error.message.includes('too long')) {
        return res.status(400).json({
          error: 'The conversation is too large to analyze',
        })
      }
    } else {
      console.error(`Unknown error: ${error}`)
    }

    return res.status(500).json({
      error: 'There was an unknown error analyzing the feedback.',
    })
  }
}
// Route entry point: runs `handler` through `apiWrapper` with `withAuth: true`.
const wrapper = (req: NextApiRequest, res: NextApiResponse) =>
apiWrapper(req, res, handler, { withAuth: true })
export default wrapper

View File

@@ -1189,6 +1189,35 @@ export interface AiAssistantInSupportFormClickedEvent {
groups: Partial<TelemetryGroups>
}
/**
 * User rated an AI assistant message with thumbs up or thumbs down.
 *
 * @group Events
 * @source studio
 */
export interface AssistantMessageRatingSubmittedEvent {
  action: 'assistant_message_rating_submitted'
  properties: {
    /**
     * The rating given by the user: positive (thumbs up) or negative (thumbs down)
     */
    rating: 'positive' | 'negative'
    /**
     * The category of the conversation
     */
    category:
      | 'sql_generation'
      | 'schema_design'
      | 'rls_policies'
      | 'edge_functions'
      | 'database_optimization'
      | 'debugging'
      | 'general_help'
      | 'other'
    /**
     * Optional free-text reason entered by the user; only present when one was provided
     */
    reason?: string
  }
  groups: TelemetryGroups
}
/**
* User copied the command for a Supabase UI component.
*
@@ -1871,6 +1900,7 @@ export type TelemetryEvent =
| AssistantSuggestionRunQueryClickedEvent
| AssistantSqlDiffHandlerEvaluatedEvent
| AssistantEditInSqlEditorClickedEvent
| AssistantMessageRatingSubmittedEvent
| DocsFeedbackClickedEvent
| HomepageFrameworkQuickstartClickedEvent
| HomepageProductCardClickedEvent