Compare commits
5 commits
main
...
feature/to
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f214f5dbc5 | ||
|
|
8f66ab5587 | ||
|
|
6a6b48c2c9 | ||
|
|
e26c21670a | ||
|
|
8590194806 |
3 changed files with 39 additions and 8 deletions
|
|
@ -14,7 +14,7 @@ const fetchMessages = () => {
|
||||||
.then(response => response.json());
|
.then(response => response.json());
|
||||||
};
|
};
|
||||||
|
|
||||||
const sendMessage = (message: string, searchType: string) => {
|
const sendMessage = (message: string, searchType: string, topK: number = 10) => {
|
||||||
return fetch("/v1/search/", {
|
return fetch("/v1/search/", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
|
|
@ -24,6 +24,7 @@ const sendMessage = (message: string, searchType: string) => {
|
||||||
query: message,
|
query: message,
|
||||||
searchType,
|
searchType,
|
||||||
datasets: ["main_dataset"],
|
datasets: ["main_dataset"],
|
||||||
|
top_k: topK,
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
.then(response => response.json());
|
.then(response => response.json());
|
||||||
|
|
@ -45,7 +46,7 @@ export default function useChat(dataset: Dataset) {
|
||||||
return setMessages(data);
|
return setMessages(data);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const handleMessageSending = useCallback((message: string, searchType: string) => {
|
const handleMessageSending = useCallback((message: string, searchType: string, topK: number = 10) => {
|
||||||
const sentMessageId = v4();
|
const sentMessageId = v4();
|
||||||
|
|
||||||
setMessages((messages) => [
|
setMessages((messages) => [
|
||||||
|
|
@ -59,7 +60,7 @@ export default function useChat(dataset: Dataset) {
|
||||||
|
|
||||||
disableSearchRun();
|
disableSearchRun();
|
||||||
|
|
||||||
return sendMessage(message, searchType)
|
return sendMessage(message, searchType, topK)
|
||||||
.then(newMessages => {
|
.then(newMessages => {
|
||||||
setMessages((messages) => [
|
setMessages((messages) => [
|
||||||
...messages,
|
...messages,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import classNames from "classnames";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
import { LoadingIndicator } from "@/ui/App";
|
import { LoadingIndicator } from "@/ui/App";
|
||||||
import { CTAButton, Select, TextArea } from "@/ui/elements";
|
import { CTAButton, Select, TextArea, Input } from "@/ui/elements";
|
||||||
import useChat from "@/modules/chat/hooks/useChat";
|
import useChat from "@/modules/chat/hooks/useChat";
|
||||||
|
|
||||||
import styles from "./SearchView.module.css";
|
import styles from "./SearchView.module.css";
|
||||||
|
|
@ -59,17 +59,28 @@ export default function SearchView() {
|
||||||
}, [refreshChat, scrollToBottom]);
|
}, [refreshChat, scrollToBottom]);
|
||||||
|
|
||||||
const [searchInputValue, setSearchInputValue] = useState("");
|
const [searchInputValue, setSearchInputValue] = useState("");
|
||||||
|
// Add state for top_k
|
||||||
|
const [topK, setTopK] = useState(10);
|
||||||
|
|
||||||
const handleSearchInputChange = useCallback((value: string) => {
|
const handleSearchInputChange = useCallback((value: string) => {
|
||||||
setSearchInputValue(value);
|
setSearchInputValue(value);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
// Add handler for top_k input
|
||||||
|
const handleTopKChange = useCallback((e: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
|
let value = parseInt(e.target.value, 10);
|
||||||
|
if (isNaN(value)) value = 10;
|
||||||
|
if (value < 1) value = 1;
|
||||||
|
if (value > 100) value = 100;
|
||||||
|
setTopK(value);
|
||||||
|
}, []);
|
||||||
|
|
||||||
const handleChatMessageSubmit = useCallback((event: React.FormEvent<SearchFormPayload>) => {
|
const handleChatMessageSubmit = useCallback((event: React.FormEvent<SearchFormPayload>) => {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
const formElements = event.currentTarget;
|
const formElements = event.currentTarget;
|
||||||
|
|
||||||
const searchType = formElements.searchType.value;
|
const searchType = formElements.searchType.value;
|
||||||
|
|
||||||
const chatInput = searchInputValue.trim();
|
const chatInput = searchInputValue.trim();
|
||||||
|
|
||||||
if (chatInput === "") {
|
if (chatInput === "") {
|
||||||
|
|
@ -79,10 +90,11 @@ export default function SearchView() {
|
||||||
scrollToBottom();
|
scrollToBottom();
|
||||||
|
|
||||||
setSearchInputValue("");
|
setSearchInputValue("");
|
||||||
|
|
||||||
sendMessage(chatInput, searchType)
|
// Pass topK to sendMessage
|
||||||
|
sendMessage(chatInput, searchType, topK)
|
||||||
.then(scrollToBottom)
|
.then(scrollToBottom)
|
||||||
}, [scrollToBottom, sendMessage, searchInputValue]);
|
}, [scrollToBottom, sendMessage, searchInputValue, topK]);
|
||||||
|
|
||||||
const chatFormRef = useRef<HTMLFormElement>(null);
|
const chatFormRef = useRef<HTMLFormElement>(null);
|
||||||
|
|
||||||
|
|
@ -132,6 +144,20 @@ export default function SearchView() {
|
||||||
<option key={option.value} value={option.value}>{option.label}</option>
|
<option key={option.value} value={option.value}>{option.label}</option>
|
||||||
))}
|
))}
|
||||||
</Select>
|
</Select>
|
||||||
|
{/* Add top_k input here */}
|
||||||
|
<label className="text-gray-600 whitespace-nowrap" title="Controls how many results to return. Smaller = focused, larger = broader graph exploration.">
|
||||||
|
Max results:
|
||||||
|
<Input
|
||||||
|
type="number"
|
||||||
|
name="topK"
|
||||||
|
min={1}
|
||||||
|
max={100}
|
||||||
|
value={topK}
|
||||||
|
onChange={handleTopKChange}
|
||||||
|
className="w-20 ml-2"
|
||||||
|
title="Controls how many results to return. Smaller = focused, larger = broader graph exploration."
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
</div>
|
</div>
|
||||||
<CTAButton disabled={isSearchRunning} type="submit">
|
<CTAButton disabled={isSearchRunning} type="submit">
|
||||||
{isSearchRunning? "Searching..." : "Search"}
|
{isSearchRunning? "Searching..." : "Search"}
|
||||||
|
|
|
||||||
|
|
@ -118,6 +118,10 @@ def get_search_router() -> APIRouter:
|
||||||
top_k=payload.top_k,
|
top_k=payload.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Ensure response conforms to response_model=list
|
||||||
|
if not isinstance(results, list):
|
||||||
|
results = [results]
|
||||||
|
|
||||||
return results
|
return results
|
||||||
except PermissionDeniedError:
|
except PermissionDeniedError:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue