feat: add llm config
This commit is contained in:
parent
9bb30bc43a
commit
84c0c8cab5
30 changed files with 9034 additions and 214 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -169,3 +169,4 @@ cognee/cache/
|
||||||
|
|
||||||
# Default cognee system directory, used in development
|
# Default cognee system directory, used in development
|
||||||
.cognee_system/
|
.cognee_system/
|
||||||
|
.data_storage/
|
||||||
|
|
|
||||||
120
cognee-frontend/package-lock.json
generated
120
cognee-frontend/package-lock.json
generated
|
|
@ -228,6 +228,126 @@
|
||||||
"node": ">= 10"
|
"node": ">= 10"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"node_modules/@next/swc-darwin-x64": {
|
||||||
|
"version": "14.2.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.3.tgz",
|
||||||
|
"integrity": "sha512-6adp7waE6P1TYFSXpY366xwsOnEXM+y1kgRpjSRVI2CBDOcbRjsJ67Z6EgKIqWIue52d2q/Mx8g9MszARj8IEA==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"darwin"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||||
|
"version": "14.2.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.3.tgz",
|
||||||
|
"integrity": "sha512-cuzCE/1G0ZSnTAHJPUT1rPgQx1w5tzSX7POXSLaS7w2nIUJUD+e25QoXD/hMfxbsT9rslEXugWypJMILBj/QsA==",
|
||||||
|
"cpu": [
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@next/swc-linux-arm64-musl": {
|
||||||
|
"version": "14.2.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.3.tgz",
|
||||||
|
"integrity": "sha512-0D4/oMM2Y9Ta3nGuCcQN8jjJjmDPYpHX9OJzqk42NZGJocU2MqhBq5tWkJrUQOQY9N+In9xOdymzapM09GeiZw==",
|
||||||
|
"cpu": [
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@next/swc-linux-x64-gnu": {
|
||||||
|
"version": "14.2.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.3.tgz",
|
||||||
|
"integrity": "sha512-ENPiNnBNDInBLyUU5ii8PMQh+4XLr4pG51tOp6aJ9xqFQ2iRI6IH0Ds2yJkAzNV1CfyagcyzPfROMViS2wOZ9w==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@next/swc-linux-x64-musl": {
|
||||||
|
"version": "14.2.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.3.tgz",
|
||||||
|
"integrity": "sha512-BTAbq0LnCbF5MtoM7I/9UeUu/8ZBY0i8SFjUMCbPDOLv+un67e2JgyN4pmgfXBwy/I+RHu8q+k+MCkDN6P9ViQ==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"linux"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||||
|
"version": "14.2.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.3.tgz",
|
||||||
|
"integrity": "sha512-AEHIw/dhAMLNFJFJIJIyOFDzrzI5bAjI9J26gbO5xhAKHYTZ9Or04BesFPXiAYXDNdrwTP2dQceYA4dL1geu8A==",
|
||||||
|
"cpu": [
|
||||||
|
"arm64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"win32"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@next/swc-win32-ia32-msvc": {
|
||||||
|
"version": "14.2.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.3.tgz",
|
||||||
|
"integrity": "sha512-vga40n1q6aYb0CLrM+eEmisfKCR45ixQYXuBXxOOmmoV8sYST9k7E3US32FsY+CkkF7NtzdcebiFT4CHuMSyZw==",
|
||||||
|
"cpu": [
|
||||||
|
"ia32"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"win32"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"node_modules/@next/swc-win32-x64-msvc": {
|
||||||
|
"version": "14.2.3",
|
||||||
|
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.3.tgz",
|
||||||
|
"integrity": "sha512-Q1/zm43RWynxrO7lW4ehciQVj+5ePBhOK+/K2P7pLFX3JaJ/IZVC69SHidrmZSOkqz7ECIOhhy7XhAFG4JYyHA==",
|
||||||
|
"cpu": [
|
||||||
|
"x64"
|
||||||
|
],
|
||||||
|
"optional": true,
|
||||||
|
"os": [
|
||||||
|
"win32"
|
||||||
|
],
|
||||||
|
"engines": {
|
||||||
|
"node": ">= 10"
|
||||||
|
}
|
||||||
|
},
|
||||||
"node_modules/@nodelib/fs.scandir": {
|
"node_modules/@nodelib/fs.scandir": {
|
||||||
"version": "2.1.5",
|
"version": "2.1.5",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,8 @@ function useDatasets() {
|
||||||
|
|
||||||
if (datasets.length > 0) {
|
if (datasets.length > 0) {
|
||||||
checkDatasetStatuses(datasets);
|
checkDatasetStatuses(datasets);
|
||||||
|
} else {
|
||||||
|
window.location.href = '/wizard';
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}, [checkDatasetStatuses]);
|
}, [checkDatasetStatuses]);
|
||||||
|
|
|
||||||
|
|
@ -15,18 +15,24 @@
|
||||||
flex: 1;
|
flex: 1;
|
||||||
padding: 16px;
|
padding: 16px;
|
||||||
border-top: 2px solid white;
|
border-top: 2px solid white;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.messagesContainer {
|
||||||
|
flex: 1;
|
||||||
|
overflow-y: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
.messages {
|
.messages {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
padding-top: 24px;
|
padding-top: 24px;
|
||||||
padding-bottom: 24px;
|
padding-bottom: 24px;
|
||||||
|
overflow-y: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
.message {
|
.message {
|
||||||
padding: 16px;
|
padding: 16px;
|
||||||
border-radius: var(--border-radius);
|
border-radius: var(--border-radius);
|
||||||
width: max-content;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.userMessage {
|
.userMessage {
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import { CTAButton, CloseIcon, GhostButton, Input, Spacer, Stack, Text } from 'ohmy-ui';
|
import { CTAButton, CloseIcon, GhostButton, Input, Spacer, Stack, Text, DropdownSelect } from 'ohmy-ui';
|
||||||
import styles from './SearchView.module.css';
|
import styles from './SearchView.module.css';
|
||||||
import { useCallback, useState } from 'react';
|
import { useCallback, useState } from 'react';
|
||||||
import { v4 } from 'uuid';
|
import { v4 } from 'uuid';
|
||||||
|
|
@ -22,10 +22,27 @@ export default function SearchView({ onClose }: SearchViewProps) {
|
||||||
setInputValue(event.target.value);
|
setInputValue(event.target.value);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
const searchOptions = [{
|
||||||
|
value: 'SIMILARITY',
|
||||||
|
label: 'Similarity',
|
||||||
|
}, {
|
||||||
|
value: 'NEIGHBOR',
|
||||||
|
label: 'Neighbor',
|
||||||
|
}, {
|
||||||
|
value: 'SUMMARY',
|
||||||
|
label: 'Summary',
|
||||||
|
}, {
|
||||||
|
value: 'ADJACENT',
|
||||||
|
label: 'Adjacent',
|
||||||
|
}, {
|
||||||
|
value: 'CATEGORIES',
|
||||||
|
label: 'Categories',
|
||||||
|
}];
|
||||||
|
const [searchType, setSearchType] = useState(searchOptions[0]);
|
||||||
|
|
||||||
const handleSearchSubmit = useCallback((event: React.FormEvent<HTMLFormElement>) => {
|
const handleSearchSubmit = useCallback((event: React.FormEvent<HTMLFormElement>) => {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
|
|
||||||
setMessages((currentMessages) => [
|
setMessages((currentMessages) => [
|
||||||
...currentMessages,
|
...currentMessages,
|
||||||
{
|
{
|
||||||
|
|
@ -43,6 +60,7 @@ export default function SearchView({ onClose }: SearchViewProps) {
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
query_params: {
|
query_params: {
|
||||||
query: inputValue,
|
query: inputValue,
|
||||||
|
searchType: searchType.value,
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
|
|
@ -58,8 +76,8 @@ export default function SearchView({ onClose }: SearchViewProps) {
|
||||||
]);
|
]);
|
||||||
setInputValue('');
|
setInputValue('');
|
||||||
})
|
})
|
||||||
}, [inputValue]);
|
}, [inputValue, searchType]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Stack className={styles.searchViewContainer}>
|
<Stack className={styles.searchViewContainer}>
|
||||||
<Stack gap="between" align="center/" orientation="horizontal">
|
<Stack gap="between" align="center/" orientation="horizontal">
|
||||||
|
|
@ -71,20 +89,27 @@ export default function SearchView({ onClose }: SearchViewProps) {
|
||||||
</GhostButton>
|
</GhostButton>
|
||||||
</Stack>
|
</Stack>
|
||||||
<Stack className={styles.searchContainer}>
|
<Stack className={styles.searchContainer}>
|
||||||
<Stack gap="2" className={styles.messages} align="end">
|
<div className={styles.messagesContainer}>
|
||||||
{messages.map((message) => (
|
<Stack gap="2" className={styles.messages} align="end">
|
||||||
<Text
|
{messages.map((message) => (
|
||||||
key={message.id}
|
<Text
|
||||||
className={classNames(styles.message, {
|
key={message.id}
|
||||||
[styles.userMessage]: message.user === "user",
|
className={classNames(styles.message, {
|
||||||
})}
|
[styles.userMessage]: message.user === "user",
|
||||||
>
|
})}
|
||||||
{message.text}
|
>
|
||||||
</Text>
|
{message.text}
|
||||||
))}
|
</Text>
|
||||||
</Stack>
|
))}
|
||||||
|
</Stack>
|
||||||
|
</div>
|
||||||
<form onSubmit={handleSearchSubmit}>
|
<form onSubmit={handleSearchSubmit}>
|
||||||
<Stack orientation="horizontal" gap="2">
|
<Stack orientation="horizontal" gap="2">
|
||||||
|
<DropdownSelect
|
||||||
|
value={searchType}
|
||||||
|
options={searchOptions}
|
||||||
|
onChange={setSearchType}
|
||||||
|
/>
|
||||||
<Input value={inputValue} onChange={handleInputChange} name="searchInput" placeholder="Search" />
|
<Input value={inputValue} onChange={handleInputChange} name="searchInput" placeholder="Search" />
|
||||||
<CTAButton type="submit">Search</CTAButton>
|
<CTAButton type="submit">Search</CTAButton>
|
||||||
</Stack>
|
</Stack>
|
||||||
|
|
|
||||||
|
|
@ -7,12 +7,22 @@ interface SelectOption {
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
|
export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
|
||||||
const [llmConfig, setLLMConfig] = useState<{ openAIApiKey: string }>();
|
const [llmConfig, setLLMConfig] = useState<{
|
||||||
|
apiKey: string;
|
||||||
|
model: SelectOption;
|
||||||
|
models: {
|
||||||
|
openai: SelectOption[];
|
||||||
|
ollama: SelectOption[];
|
||||||
|
anthropic: SelectOption[];
|
||||||
|
};
|
||||||
|
provider: SelectOption;
|
||||||
|
providers: SelectOption[];
|
||||||
|
}>();
|
||||||
const [vectorDBConfig, setVectorDBConfig] = useState<{
|
const [vectorDBConfig, setVectorDBConfig] = useState<{
|
||||||
choice: SelectOption;
|
|
||||||
options: SelectOption[];
|
|
||||||
url: string;
|
url: string;
|
||||||
apiKey: string;
|
apiKey: string;
|
||||||
|
provider: SelectOption;
|
||||||
|
options: SelectOption[];
|
||||||
}>();
|
}>();
|
||||||
|
|
||||||
const {
|
const {
|
||||||
|
|
@ -23,10 +33,18 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
|
||||||
|
|
||||||
const saveConfig = (event: React.FormEvent<HTMLFormElement>) => {
|
const saveConfig = (event: React.FormEvent<HTMLFormElement>) => {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
const newOpenAIApiKey = event.target.openAIApiKey.value;
|
|
||||||
const newVectorDBChoice = vectorDBConfig?.choice.value;
|
const newVectorConfig = {
|
||||||
const newVectorDBUrl = event.target.vectorDBUrl.value;
|
provider: vectorDBConfig?.provider.value,
|
||||||
const newVectorDBApiKey = event.target.vectorDBApiKey.value;
|
url: event.target.vectorDBUrl.value,
|
||||||
|
apiKey: event.target.vectorDBApiKey.value,
|
||||||
|
};
|
||||||
|
|
||||||
|
const newLLMConfig = {
|
||||||
|
provider: llmConfig?.provider.value,
|
||||||
|
model: llmConfig?.model.value,
|
||||||
|
apiKey: event.target.llmApiKey.value,
|
||||||
|
};
|
||||||
|
|
||||||
startSaving();
|
startSaving();
|
||||||
|
|
||||||
|
|
@ -36,14 +54,8 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
llm: {
|
llm: newLLMConfig,
|
||||||
openAIApiKey: newOpenAIApiKey,
|
vectorDB: newVectorConfig,
|
||||||
},
|
|
||||||
vectorDB: {
|
|
||||||
choice: newVectorDBChoice,
|
|
||||||
url: newVectorDBUrl,
|
|
||||||
apiKey: newVectorDBApiKey,
|
|
||||||
},
|
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
.then(() => {
|
.then(() => {
|
||||||
|
|
@ -52,12 +64,12 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
|
||||||
.finally(() => stopSaving());
|
.finally(() => stopSaving());
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleVectorDBChange = useCallback((newChoice: SelectOption) => {
|
const handleVectorDBChange = useCallback((newVectorDBProvider: SelectOption) => {
|
||||||
setVectorDBConfig((config) => {
|
setVectorDBConfig((config) => {
|
||||||
if (config?.choice !== newChoice) {
|
if (config?.provider !== newVectorDBProvider) {
|
||||||
return {
|
return {
|
||||||
...config,
|
...config,
|
||||||
choice: newChoice,
|
provider: newVectorDBProvider,
|
||||||
url: '',
|
url: '',
|
||||||
apiKey: '',
|
apiKey: '',
|
||||||
};
|
};
|
||||||
|
|
@ -66,11 +78,40 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
|
||||||
});
|
});
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
const handleLLMProviderChange = useCallback((newLLMProvider: SelectOption) => {
|
||||||
|
setLLMConfig((config) => {
|
||||||
|
if (config?.provider !== newLLMProvider) {
|
||||||
|
return {
|
||||||
|
...config,
|
||||||
|
provider: newLLMProvider,
|
||||||
|
model: config?.models[newLLMProvider.value][0],
|
||||||
|
apiKey: '',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const handleLLMModelChange = useCallback((newLLMModel: SelectOption) => {
|
||||||
|
setLLMConfig((config) => {
|
||||||
|
if (config?.model !== newLLMModel) {
|
||||||
|
return {
|
||||||
|
...config,
|
||||||
|
model: newLLMModel,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const fetchVectorDBChoices = async () => {
|
const fetchVectorDBChoices = async () => {
|
||||||
const response = await fetch('http://0.0.0.0:8000/settings');
|
const response = await fetch('http://0.0.0.0:8000/settings');
|
||||||
const settings = await response.json();
|
const settings = await response.json();
|
||||||
|
|
||||||
|
if (!settings.llm.model) {
|
||||||
|
settings.llm.model = settings.llm.models[settings.llm.provider.value][0];
|
||||||
|
}
|
||||||
setLLMConfig(settings.llm);
|
setLLMConfig(settings.llm);
|
||||||
setVectorDBConfig(settings.vectorDB);
|
setVectorDBConfig(settings.vectorDB);
|
||||||
};
|
};
|
||||||
|
|
@ -79,40 +120,54 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Modal isOpen={isOpen} onClose={onClose}>
|
<Modal isOpen={isOpen} onClose={onClose}>
|
||||||
<Stack gap="4" orientation="vertical" align="center/">
|
<Stack gap="8" orientation="vertical" align="center/">
|
||||||
<H2>Settings</H2>
|
<H2>Settings</H2>
|
||||||
<form onSubmit={saveConfig} style={{ width: '100%' }}>
|
<form onSubmit={saveConfig} style={{ width: '100%' }}>
|
||||||
<Stack gap="2" orientation="vertical">
|
<Stack gap="4" orientation="vertical">
|
||||||
<H3>LLM Config</H3>
|
<Stack gap="2" orientation="vertical">
|
||||||
<FormGroup orientation="vertical" align="center/" gap="1">
|
<H3>LLM Config</H3>
|
||||||
<FormLabel>OpenAI API Key</FormLabel>
|
<FormGroup orientation="horizontal" align="center/" gap="4">
|
||||||
|
<FormLabel>LLM provider:</FormLabel>
|
||||||
|
<DropdownSelect
|
||||||
|
value={llmConfig?.provider}
|
||||||
|
options={llmConfig?.providers}
|
||||||
|
onChange={handleLLMProviderChange}
|
||||||
|
/>
|
||||||
|
</FormGroup>
|
||||||
|
<FormGroup orientation="horizontal" align="center/" gap="4">
|
||||||
|
<FormLabel>LLM model:</FormLabel>
|
||||||
|
<DropdownSelect
|
||||||
|
value={llmConfig?.model}
|
||||||
|
options={llmConfig?.provider ? llmConfig?.models[llmConfig?.provider.value] : []}
|
||||||
|
onChange={handleLLMModelChange}
|
||||||
|
/>
|
||||||
|
</FormGroup>
|
||||||
<FormInput>
|
<FormInput>
|
||||||
<Input defaultValue={llmConfig?.openAIApiKey} name="openAIApiKey" placeholder="OpenAI API Key" />
|
<Input defaultValue={llmConfig?.apiKey} name="llmApiKey" placeholder="LLM API key" />
|
||||||
</FormInput>
|
</FormInput>
|
||||||
</FormGroup>
|
</Stack>
|
||||||
|
|
||||||
<H3>Vector Database Config</H3>
|
<Stack gap="2" orientation="vertical">
|
||||||
<DropdownSelect
|
<H3>Vector Database Config</H3>
|
||||||
value={vectorDBConfig?.choice}
|
<FormGroup orientation="horizontal" align="center/" gap="4">
|
||||||
options={vectorDBConfig?.options}
|
<FormLabel>Vector DB provider:</FormLabel>
|
||||||
onChange={handleVectorDBChange}
|
<DropdownSelect
|
||||||
/>
|
value={vectorDBConfig?.provider}
|
||||||
<FormGroup orientation="vertical" align="center/" gap="1">
|
options={vectorDBConfig?.options}
|
||||||
<FormLabel>Vector DB url</FormLabel>
|
onChange={handleVectorDBChange}
|
||||||
|
/>
|
||||||
|
</FormGroup>
|
||||||
<FormInput>
|
<FormInput>
|
||||||
<Input defaultValue={vectorDBConfig?.url} name="vectorDBUrl" placeholder="Vector DB API url" />
|
<Input defaultValue={vectorDBConfig?.url} name="vectorDBUrl" placeholder="Vector DB instance url" />
|
||||||
</FormInput>
|
</FormInput>
|
||||||
</FormGroup>
|
|
||||||
<FormGroup orientation="vertical" align="center/" gap="1">
|
|
||||||
<FormLabel>Vector DB API key</FormLabel>
|
|
||||||
<FormInput>
|
<FormInput>
|
||||||
<Input defaultValue={vectorDBConfig?.apiKey} name="vectorDBApiKey" placeholder="Vector DB API key" />
|
<Input defaultValue={vectorDBConfig?.apiKey} name="vectorDBApiKey" placeholder="Vector DB API key" />
|
||||||
</FormInput>
|
</FormInput>
|
||||||
</FormGroup>
|
<Stack align="/end">
|
||||||
<Stack align="/end">
|
<Spacer top="2">
|
||||||
<Spacer top="2">
|
<CTAButton type="submit">Save</CTAButton>
|
||||||
<CTAButton type="submit">Save</CTAButton>
|
</Spacer>
|
||||||
</Spacer>
|
</Stack>
|
||||||
</Stack>
|
</Stack>
|
||||||
</Stack>
|
</Stack>
|
||||||
</form>
|
</form>
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import os
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
import json
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
# Set up logging
|
# Set up logging
|
||||||
|
|
@ -16,7 +17,7 @@ from cognee.config import Config
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
|
|
||||||
from typing import Dict, Any, List, Union, Annotated, Literal
|
from typing import Dict, Any, List, Union, Annotated, Literal, Optional
|
||||||
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
|
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
|
||||||
from fastapi.responses import JSONResponse, FileResponse
|
from fastapi.responses import JSONResponse, FileResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
@ -25,6 +26,7 @@ from pydantic import BaseModel
|
||||||
app = FastAPI(debug=True)
|
app = FastAPI(debug=True)
|
||||||
|
|
||||||
origins = [
|
origins = [
|
||||||
|
"http://frontend:3000",
|
||||||
"http://localhost:3000",
|
"http://localhost:3000",
|
||||||
"http://localhost:3001",
|
"http://localhost:3001",
|
||||||
]
|
]
|
||||||
|
|
@ -220,8 +222,16 @@ async def search(payload: SearchPayload):
|
||||||
from cognee import search as cognee_search
|
from cognee import search as cognee_search
|
||||||
|
|
||||||
try:
|
try:
|
||||||
search_type = "SIMILARITY"
|
search_type = payload.query_params["searchType"]
|
||||||
await cognee_search(search_type, payload.query_params)
|
params = {
|
||||||
|
"query": payload.query_params["query"],
|
||||||
|
}
|
||||||
|
results = await cognee_search(search_type, params)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code = 200,
|
||||||
|
content = json.dumps(results)
|
||||||
|
)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code = 409,
|
status_code = 409,
|
||||||
|
|
@ -236,25 +246,26 @@ async def get_settings():
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig(BaseModel):
|
class LLMConfig(BaseModel):
|
||||||
openAIApiKey: str
|
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
|
||||||
|
model: str
|
||||||
|
apiKey: str
|
||||||
|
|
||||||
class VectorDBConfig(BaseModel):
|
class VectorDBConfig(BaseModel):
|
||||||
choice: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
|
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
|
||||||
url: str
|
url: str
|
||||||
apiKey: str
|
apiKey: str
|
||||||
|
|
||||||
class SettingsPayload(BaseModel):
|
class SettingsPayload(BaseModel):
|
||||||
llm: LLMConfig | None = None
|
llm: Optional[LLMConfig] = None
|
||||||
vectorDB: VectorDBConfig | None = None
|
vectorDB: Optional[VectorDBConfig] = None
|
||||||
|
|
||||||
@app.post("/settings", response_model=dict)
|
@app.post("/settings", response_model=dict)
|
||||||
async def save_config(new_settings: SettingsPayload):
|
async def save_config(new_settings: SettingsPayload):
|
||||||
from cognee.modules.settings import save_llm_config, save_vector_db_config
|
from cognee.modules.settings import save_llm_config, save_vector_db_config
|
||||||
|
if new_settings.llm is not None:
|
||||||
if hasattr(new_settings, "llm"):
|
|
||||||
await save_llm_config(new_settings.llm)
|
await save_llm_config(new_settings.llm)
|
||||||
|
|
||||||
if hasattr(new_settings, "vectorDB"):
|
if new_settings.vectorDB is not None:
|
||||||
await save_vector_db_config(new_settings.vectorDB)
|
await save_vector_db_config(new_settings.vectorDB)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,13 @@ from dotenv import load_dotenv
|
||||||
from cognee.root_dir import get_absolute_path
|
from cognee.root_dir import get_absolute_path
|
||||||
from cognee.shared.data_models import ChunkStrategy, DefaultGraphModel
|
from cognee.shared.data_models import ChunkStrategy, DefaultGraphModel
|
||||||
|
|
||||||
base_dir = Path(__file__).resolve().parent.parent
|
def load_dontenv():
|
||||||
# Load the .env file from the base directory
|
base_dir = Path(__file__).resolve().parent.parent
|
||||||
dotenv_path = base_dir / ".env"
|
# Load the .env file from the base directory
|
||||||
load_dotenv(dotenv_path=dotenv_path)
|
dotenv_path = base_dir / ".env"
|
||||||
|
load_dotenv(dotenv_path=dotenv_path, override = True)
|
||||||
|
|
||||||
|
load_dontenv()
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
|
|
@ -50,16 +52,20 @@ class Config:
|
||||||
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
|
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
llm_provider: str = os.getenv("LLM_PROVIDER","openai") #openai, or custom or ollama
|
llm_provider: str = os.getenv("LLM_PROVIDER", "openai") #openai, or custom or ollama
|
||||||
custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
|
llm_model: str = os.getenv("LLM_MODEL", None)
|
||||||
custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
|
llm_api_key: str = os.getenv("LLM_API_KEY", None)
|
||||||
custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
|
llm_endpoint: str = os.getenv("LLM_ENDPOINT", None)
|
||||||
ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1"
|
|
||||||
ollama_key: Optional[str] = "ollama"
|
# custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct"
|
# custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
|
||||||
openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o"
|
# custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
|
||||||
model_endpoint: str = "openai"
|
# ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1"
|
||||||
openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
|
# ollama_key: Optional[str] = "ollama"
|
||||||
|
# ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct"
|
||||||
|
# openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o"
|
||||||
|
# model_endpoint: str = "openai"
|
||||||
|
# llm_api_key: Optional[str] = os.getenv("OPENAI_API_KEY")
|
||||||
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
|
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
|
||||||
openai_embedding_model = "text-embedding-3-large"
|
openai_embedding_model = "text-embedding-3-large"
|
||||||
openai_embedding_dimensions = 3072
|
openai_embedding_dimensions = 3072
|
||||||
|
|
@ -132,6 +138,7 @@ class Config:
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
"""Loads the configuration from a file or environment variables."""
|
"""Loads the configuration from a file or environment variables."""
|
||||||
|
load_dontenv()
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
config.read(self.config_path)
|
config.read(self.config_path)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from .databases.relational import DuckDBAdapter, DatabaseEngine
|
||||||
from .databases.vector.vector_db_interface import VectorDBInterface
|
from .databases.vector.vector_db_interface import VectorDBInterface
|
||||||
from .databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbeddingEngine
|
from .databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbeddingEngine
|
||||||
from .llm.llm_interface import LLMInterface
|
from .llm.llm_interface import LLMInterface
|
||||||
from .llm.openai.adapter import OpenAIAdapter
|
from .llm.get_llm_client import get_llm_client
|
||||||
from .files.storage import LocalStorage
|
from .files.storage import LocalStorage
|
||||||
from .data.chunking.DefaultChunkEngine import DefaultChunkEngine
|
from .data.chunking.DefaultChunkEngine import DefaultChunkEngine
|
||||||
from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \
|
from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \
|
||||||
|
|
@ -35,6 +35,10 @@ class InfrastructureConfig():
|
||||||
chunk_engine = None
|
chunk_engine = None
|
||||||
graph_topology = config.graph_topology
|
graph_topology = config.graph_topology
|
||||||
monitoring_tool = config.monitoring_tool
|
monitoring_tool = config.monitoring_tool
|
||||||
|
llm_provider: str = None
|
||||||
|
llm_model: str = None
|
||||||
|
llm_endpoint: str = None
|
||||||
|
llm_api_key: str = None
|
||||||
|
|
||||||
def get_config(self, config_entity: str = None) -> dict:
|
def get_config(self, config_entity: str = None) -> dict:
|
||||||
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
|
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
|
||||||
|
|
@ -84,7 +88,8 @@ class InfrastructureConfig():
|
||||||
self.graph_topology = config.graph_topology
|
self.graph_topology = config.graph_topology
|
||||||
|
|
||||||
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
|
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
|
||||||
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
|
self.llm_engine = get_llm_client()
|
||||||
|
|
||||||
if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
|
if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
|
||||||
self.database_directory_path = self.system_root_directory + "/" + config.db_path
|
self.database_directory_path = self.system_root_directory + "/" + config.db_path
|
||||||
|
|
||||||
|
|
@ -115,8 +120,8 @@ class InfrastructureConfig():
|
||||||
from .databases.vector.qdrant.QDrantAdapter import QDrantAdapter
|
from .databases.vector.qdrant.QDrantAdapter import QDrantAdapter
|
||||||
|
|
||||||
self.vector_engine = QDrantAdapter(
|
self.vector_engine = QDrantAdapter(
|
||||||
qdrant_url = config.qdrant_url,
|
url = config.qdrant_url,
|
||||||
qdrant_api_key = config.qdrant_api_key,
|
api_key = config.qdrant_api_key,
|
||||||
embedding_engine = self.embedding_engine
|
embedding_engine = self.embedding_engine
|
||||||
)
|
)
|
||||||
self.vector_engine_choice = "qdrant"
|
self.vector_engine_choice = "qdrant"
|
||||||
|
|
@ -127,11 +132,10 @@ class InfrastructureConfig():
|
||||||
LocalStorage.ensure_directory_exists(lance_db_path)
|
LocalStorage.ensure_directory_exists(lance_db_path)
|
||||||
|
|
||||||
self.vector_engine = LanceDBAdapter(
|
self.vector_engine = LanceDBAdapter(
|
||||||
uri = lance_db_path,
|
url = lance_db_path,
|
||||||
api_key = None,
|
api_key = None,
|
||||||
embedding_engine = self.embedding_engine,
|
embedding_engine = self.embedding_engine,
|
||||||
)
|
)
|
||||||
self.lance_db_path = lance_db_path
|
|
||||||
self.vector_engine_choice = "lancedb"
|
self.vector_engine_choice = "lancedb"
|
||||||
|
|
||||||
if config_entity is not None:
|
if config_entity is not None:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from typing import List, Optional, get_type_hints, Generic, TypeVar
|
from typing import List, Optional, get_type_hints, Generic, TypeVar
|
||||||
import asyncio
|
import asyncio
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import lancedb
|
import lancedb
|
||||||
from lancedb.pydantic import Vector, LanceModel
|
from lancedb.pydantic import Vector, LanceModel
|
||||||
from cognee.infrastructure.files.storage import LocalStorage
|
from cognee.infrastructure.files.storage import LocalStorage
|
||||||
|
|
@ -9,21 +8,25 @@ from ..vector_db_interface import VectorDBInterface, DataPoint
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
class LanceDBAdapter(VectorDBInterface):
|
class LanceDBAdapter(VectorDBInterface):
|
||||||
|
name = "LanceDB"
|
||||||
|
url: str
|
||||||
|
api_key: str
|
||||||
connection: lancedb.AsyncConnection = None
|
connection: lancedb.AsyncConnection = None
|
||||||
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
uri: Optional[str],
|
url: Optional[str],
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
embedding_engine: EmbeddingEngine,
|
embedding_engine: EmbeddingEngine,
|
||||||
):
|
):
|
||||||
self.uri = uri
|
self.url = url
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
|
|
||||||
async def get_connection(self):
|
async def get_connection(self):
|
||||||
if self.connection is None:
|
if self.connection is None:
|
||||||
self.connection = await lancedb.connect_async(self.uri, api_key = self.api_key)
|
self.connection = await lancedb.connect_async(self.url, api_key = self.api_key)
|
||||||
|
|
||||||
return self.connection
|
return self.connection
|
||||||
|
|
||||||
|
|
@ -35,12 +38,12 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
collection_names = await connection.table_names()
|
collection_names = await connection.table_names()
|
||||||
return collection_name in collection_names
|
return collection_name in collection_names
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
|
async def create_collection(self, collection_name: str, payload_schema = None):
|
||||||
data_point_types = get_type_hints(DataPoint)
|
data_point_types = get_type_hints(DataPoint)
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
class LanceDataPoint(LanceModel):
|
class LanceDataPoint(LanceModel):
|
||||||
id: data_point_types["id"] = Field(...)
|
id: data_point_types["id"]
|
||||||
vector: Vector(vector_size)
|
vector: Vector(vector_size)
|
||||||
payload: payload_schema
|
payload: payload_schema
|
||||||
|
|
||||||
|
|
@ -128,7 +131,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
query_texts: List[str],
|
query_texts: List[str],
|
||||||
limit: int = None,
|
limit: int = None,
|
||||||
with_vector: bool = False,
|
with_vectors: bool = False,
|
||||||
):
|
):
|
||||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||||
|
|
||||||
|
|
@ -137,11 +140,11 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
collection_name = collection_name,
|
collection_name = collection_name,
|
||||||
query_vector = query_vector,
|
query_vector = query_vector,
|
||||||
limit = limit,
|
limit = limit,
|
||||||
with_vector = with_vector,
|
with_vector = with_vectors,
|
||||||
) for query_vector in query_vectors]
|
) for query_vector in query_vectors]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def prune(self):
|
async def prune(self):
|
||||||
# Clean up the database if it was set up as temporary
|
# Clean up the database if it was set up as temporary
|
||||||
if self.uri.startswith("/"):
|
if self.url.startswith("/"):
|
||||||
LocalStorage.remove_all(self.uri) # Remove the temporary directory and files inside
|
LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
|
||||||
|
|
|
||||||
|
|
@ -26,29 +26,29 @@ def create_quantization_config(quantization_config: Dict):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
class QDrantAdapter(VectorDBInterface):
|
class QDrantAdapter(VectorDBInterface):
|
||||||
qdrant_url: str = None
|
name = "Qdrant"
|
||||||
|
url: str = None
|
||||||
|
api_key: str = None
|
||||||
qdrant_path: str = None
|
qdrant_path: str = None
|
||||||
qdrant_api_key: str = None
|
|
||||||
|
|
||||||
def __init__(self, qdrant_url, qdrant_api_key, embedding_engine: EmbeddingEngine, qdrant_path = None):
|
def __init__(self, url, api_key, embedding_engine: EmbeddingEngine, qdrant_path = None):
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
|
|
||||||
if qdrant_path is not None:
|
if qdrant_path is not None:
|
||||||
self.qdrant_path = qdrant_path
|
self.qdrant_path = qdrant_path
|
||||||
else:
|
else:
|
||||||
self.qdrant_url = qdrant_url
|
self.url = url
|
||||||
|
self.api_key = api_key
|
||||||
self.qdrant_api_key = qdrant_api_key
|
|
||||||
|
|
||||||
def get_qdrant_client(self) -> AsyncQdrantClient:
|
def get_qdrant_client(self) -> AsyncQdrantClient:
|
||||||
if self.qdrant_path is not None:
|
if self.qdrant_path is not None:
|
||||||
return AsyncQdrantClient(
|
return AsyncQdrantClient(
|
||||||
path = self.qdrant_path, port=6333
|
path = self.qdrant_path, port=6333
|
||||||
)
|
)
|
||||||
elif self.qdrant_url is not None:
|
elif self.url is not None:
|
||||||
return AsyncQdrantClient(
|
return AsyncQdrantClient(
|
||||||
url = self.qdrant_url,
|
url = self.url,
|
||||||
api_key = self.qdrant_api_key,
|
api_key = self.api_key,
|
||||||
port = 6333
|
port = 6333
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from multiprocessing import Pool
|
|
||||||
from ..vector_db_interface import VectorDBInterface
|
from ..vector_db_interface import VectorDBInterface
|
||||||
from ..models.DataPoint import DataPoint
|
from ..models.DataPoint import DataPoint
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
|
|
@ -9,19 +8,24 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
|
|
||||||
|
|
||||||
class WeaviateAdapter(VectorDBInterface):
|
class WeaviateAdapter(VectorDBInterface):
|
||||||
async_pool: Pool = None
|
name = "Weaviate"
|
||||||
|
url: str
|
||||||
|
api_key: str
|
||||||
embedding_engine: EmbeddingEngine = None
|
embedding_engine: EmbeddingEngine = None
|
||||||
|
|
||||||
def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine):
|
def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine):
|
||||||
import weaviate
|
import weaviate
|
||||||
import weaviate.classes as wvc
|
import weaviate.classes as wvc
|
||||||
|
|
||||||
|
self.url = url
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
|
|
||||||
self.client = weaviate.connect_to_wcs(
|
self.client = weaviate.connect_to_wcs(
|
||||||
cluster_url=url,
|
cluster_url = url,
|
||||||
auth_credentials=weaviate.auth.AuthApiKey(api_key),
|
auth_credentials = weaviate.auth.AuthApiKey(api_key),
|
||||||
additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30))
|
additional_config = wvc.init.AdditionalConfig(timeout = wvc.init.Timeout(init=30))
|
||||||
)
|
)
|
||||||
|
|
||||||
async def embed_data(self, data: List[str]) -> List[float]:
|
async def embed_data(self, data: List[str]) -> List[float]:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .config import llm_config
|
||||||
|
|
@ -8,7 +8,9 @@ from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
|
|
||||||
|
|
||||||
class AnthropicAdapter(LLMInterface):
|
class AnthropicAdapter(LLMInterface):
|
||||||
"""Adapter for Ollama's API"""
|
"""Adapter for Anthropic API"""
|
||||||
|
name = "Anthropic"
|
||||||
|
model: str
|
||||||
|
|
||||||
def __init__(self, model: str = None):
|
def __init__(self, model: str = None):
|
||||||
self.aclient = instructor.patch(
|
self.aclient = instructor.patch(
|
||||||
|
|
|
||||||
16
cognee/infrastructure/llm/config.py
Normal file
16
cognee/infrastructure/llm/config.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
|
||||||
|
class LLMConfig():
|
||||||
|
llm_provider: str = None
|
||||||
|
llm_model: str = None
|
||||||
|
llm_endpoint: str = None
|
||||||
|
llm_api_key: str = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"provider": self.llm_provider,
|
||||||
|
"model": self.llm_model,
|
||||||
|
"endpoint": self.llm_endpoint,
|
||||||
|
"apiKey": self.llm_api_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
llm_config = LLMConfig()
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
from typing import List, Type
|
from typing import List, Type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import instructor
|
import instructor
|
||||||
from tenacity import retry, stop_after_attempt
|
from tenacity import retry, stop_after_attempt
|
||||||
from openai import AsyncOpenAI
|
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
from cognee.config import Config
|
from cognee.config import Config
|
||||||
|
|
@ -19,23 +17,31 @@ config.load()
|
||||||
if config.monitoring_tool == MonitoringTool.LANGFUSE:
|
if config.monitoring_tool == MonitoringTool.LANGFUSE:
|
||||||
from langfuse.openai import AsyncOpenAI, OpenAI
|
from langfuse.openai import AsyncOpenAI, OpenAI
|
||||||
elif config.monitoring_tool == MonitoringTool.LANGSMITH:
|
elif config.monitoring_tool == MonitoringTool.LANGSMITH:
|
||||||
from langsmith import wrap_openai
|
from langsmith import wrappers
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
AsyncOpenAI = wrap_openai(AsyncOpenAI())
|
AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
|
||||||
else:
|
else:
|
||||||
from openai import AsyncOpenAI, OpenAI
|
from openai import AsyncOpenAI, OpenAI
|
||||||
|
|
||||||
class GenericAPIAdapter(LLMInterface):
|
class GenericAPIAdapter(LLMInterface):
|
||||||
"""Adapter for Generic API LLM provider API """
|
"""Adapter for Generic API LLM provider API """
|
||||||
|
name: str
|
||||||
|
model: str
|
||||||
|
api_key: str
|
||||||
|
|
||||||
def __init__(self, api_endpoint, api_key: str, model: str):
|
def __init__(self, api_endpoint, api_key: str, model: str, name: str):
|
||||||
|
self.name = name
|
||||||
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
if infrastructure_config.get_config()["llm_provider"] == "groq":
|
||||||
if infrastructure_config.get_config()["llm_provider"] == 'groq':
|
|
||||||
from groq import groq
|
from groq import groq
|
||||||
self.aclient = instructor.from_openai(client = groq.Groq(
|
self.aclient = instructor.from_openai(
|
||||||
api_key=api_key,
|
client = groq.Groq(
|
||||||
), mode=instructor.Mode.MD_JSON)
|
api_key = api_key,
|
||||||
|
),
|
||||||
|
mode = instructor.Mode.MD_JSON
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.aclient = instructor.patch(
|
self.aclient = instructor.patch(
|
||||||
AsyncOpenAI(
|
AsyncOpenAI(
|
||||||
|
|
@ -45,9 +51,6 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
mode = instructor.Mode.JSON,
|
mode = instructor.Mode.JSON,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
@retry(stop = stop_after_attempt(5))
|
||||||
def completions_with_backoff(self, **kwargs):
|
def completions_with_backoff(self, **kwargs):
|
||||||
"""Wrapper around ChatCompletion.create w/ backoff"""
|
"""Wrapper around ChatCompletion.create w/ backoff"""
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,8 @@
|
||||||
"""Get the LLM client."""
|
"""Get the LLM client."""
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from cognee.config import Config
|
import json
|
||||||
from .anthropic.adapter import AnthropicAdapter
|
import logging
|
||||||
from .openai.adapter import OpenAIAdapter
|
from cognee.infrastructure.llm import llm_config
|
||||||
from .generic_llm_api.adapter import GenericAPIAdapter
|
|
||||||
|
|
||||||
# Define an Enum for LLM Providers
|
# Define an Enum for LLM Providers
|
||||||
class LLMProvider(Enum):
|
class LLMProvider(Enum):
|
||||||
|
|
@ -12,20 +11,22 @@ class LLMProvider(Enum):
|
||||||
ANTHROPIC = "anthropic"
|
ANTHROPIC = "anthropic"
|
||||||
CUSTOM = "custom"
|
CUSTOM = "custom"
|
||||||
|
|
||||||
config = Config()
|
|
||||||
config.load()
|
|
||||||
|
|
||||||
def get_llm_client():
|
def get_llm_client():
|
||||||
"""Get the LLM client based on the configuration using Enums."""
|
"""Get the LLM client based on the configuration using Enums."""
|
||||||
provider = LLMProvider(config.llm_provider)
|
logging.error(json.dumps(llm_config.to_dict()))
|
||||||
|
provider = LLMProvider(llm_config.llm_provider)
|
||||||
|
|
||||||
if provider == LLMProvider.OPENAI:
|
if provider == LLMProvider.OPENAI:
|
||||||
return OpenAIAdapter(config.openai_key, config.openai_model)
|
from .openai.adapter import OpenAIAdapter
|
||||||
|
return OpenAIAdapter(llm_config.llm_api_key, llm_config.llm_model)
|
||||||
elif provider == LLMProvider.OLLAMA:
|
elif provider == LLMProvider.OLLAMA:
|
||||||
return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
|
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
|
||||||
elif provider == LLMProvider.ANTHROPIC:
|
elif provider == LLMProvider.ANTHROPIC:
|
||||||
return AnthropicAdapter(config.custom_model)
|
from .anthropic.adapter import AnthropicAdapter
|
||||||
|
return AnthropicAdapter(llm_config.llm_model)
|
||||||
elif provider == LLMProvider.CUSTOM:
|
elif provider == LLMProvider.CUSTOM:
|
||||||
return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
|
from .generic_llm_api.adapter import GenericAPIAdapter
|
||||||
|
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||||
|
|
|
||||||
|
|
@ -23,12 +23,16 @@ else:
|
||||||
from openai import AsyncOpenAI, OpenAI
|
from openai import AsyncOpenAI, OpenAI
|
||||||
|
|
||||||
class OpenAIAdapter(LLMInterface):
|
class OpenAIAdapter(LLMInterface):
|
||||||
|
name = "OpenAI"
|
||||||
|
model: str
|
||||||
|
api_key: str
|
||||||
|
|
||||||
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
|
||||||
def __init__(self, api_key: str, model:str):
|
def __init__(self, api_key: str, model:str):
|
||||||
openai.api_key = api_key
|
self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
|
||||||
self.aclient = instructor.from_openai(AsyncOpenAI())
|
self.client = instructor.from_openai(OpenAI(api_key = api_key))
|
||||||
self.client = instructor.from_openai(OpenAI())
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
@retry(stop = stop_after_attempt(5))
|
||||||
def completions_with_backoff(self, **kwargs):
|
def completions_with_backoff(self, **kwargs):
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ def evaluate():
|
||||||
|
|
||||||
evaluate_on_hotpotqa = Evaluate(devset = devset, num_threads = 1, display_progress = True, display_table = 5, max_tokens = 4096)
|
evaluate_on_hotpotqa = Evaluate(devset = devset, num_threads = 1, display_progress = True, display_table = 5, max_tokens = 4096)
|
||||||
|
|
||||||
gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
|
gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)
|
||||||
compiled_extract_knowledge_graph = ExtractKnowledgeGraph(lm = gpt4)
|
compiled_extract_knowledge_graph = ExtractKnowledgeGraph(lm = gpt4)
|
||||||
compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))
|
compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))
|
||||||
|
|
||||||
|
|
@ -58,7 +58,7 @@ def evaluate():
|
||||||
return dsp.answer_match(example.answer, [answer_prediction.answer], frac = 0.8) or \
|
return dsp.answer_match(example.answer, [answer_prediction.answer], frac = 0.8) or \
|
||||||
dsp.passage_match([example.answer], [answer_prediction.answer])
|
dsp.passage_match([example.answer], [answer_prediction.answer])
|
||||||
|
|
||||||
gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
|
gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)
|
||||||
dspy.settings.configure(lm = gpt4)
|
dspy.settings.configure(lm = gpt4)
|
||||||
|
|
||||||
evaluate_on_hotpotqa(compiled_extract_knowledge_graph, metric = evaluate_answer)
|
evaluate_on_hotpotqa(compiled_extract_knowledge_graph, metric = evaluate_answer)
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
|
gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)
|
||||||
compiled_extract_knowledge_graph = ExtractKnowledgeGraph(lm = gpt4)
|
compiled_extract_knowledge_graph = ExtractKnowledgeGraph(lm = gpt4)
|
||||||
compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))
|
compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ def train():
|
||||||
|
|
||||||
trainset = [example.with_inputs("context", "question") for example in train_examples]
|
trainset = [example.with_inputs("context", "question") for example in train_examples]
|
||||||
|
|
||||||
gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
|
gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)
|
||||||
|
|
||||||
compiled_extract_knowledge_graph = optimizer.compile(ExtractKnowledgeGraph(lm = gpt4), trainset = trainset)
|
compiled_extract_knowledge_graph = optimizer.compile(ExtractKnowledgeGraph(lm = gpt4), trainset = trainset)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ def are_all_nodes_connected(graph: KnowledgeGraph) -> bool:
|
||||||
|
|
||||||
|
|
||||||
class ExtractKnowledgeGraph(dspy.Module):
|
class ExtractKnowledgeGraph(dspy.Module):
|
||||||
def __init__(self, lm = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)):
|
def __init__(self, lm = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lm = lm
|
self.lm = lm
|
||||||
dspy.settings.configure(lm=self.lm)
|
dspy.settings.configure(lm=self.lm)
|
||||||
|
|
@ -50,7 +50,7 @@ class ExtractKnowledgeGraph(dspy.Module):
|
||||||
|
|
||||||
def forward(self, context: str, question: str):
|
def forward(self, context: str, question: str):
|
||||||
context = remove_stop_words(context)
|
context = remove_stop_words(context)
|
||||||
context = trim_text_to_max_tokens(context, 1500, config.openai_model)
|
context = trim_text_to_max_tokens(context, 1500, config.llm_model)
|
||||||
|
|
||||||
with dspy.context(lm = self.lm):
|
with dspy.context(lm = self.lm):
|
||||||
graph = self.generate_graph(text = context).graph
|
graph = self.generate_graph(text = context).graph
|
||||||
|
|
@ -79,7 +79,7 @@ def remove_stop_words(text):
|
||||||
|
|
||||||
#
|
#
|
||||||
# if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
# gpt_4_turbo = dspy.OpenAI(model="gpt-4", max_tokens=4000, api_key=config.openai_key, model_type="chat")
|
# gpt_4_turbo = dspy.OpenAI(model="gpt-4", max_tokens=4000, api_key=config.llm_api_key, model_type="chat")
|
||||||
# dspy.settings.configure(lm=gpt_4_turbo)
|
# dspy.settings.configure(lm=gpt_4_turbo)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@
|
||||||
from typing import Union, Dict
|
from typing import Union, Dict
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from cognee.shared.data_models import GraphDBType
|
from cognee.shared.data_models import GraphDBType
|
||||||
async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]:
|
async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Find the neighbours of a given node in the graph and return their descriptions.
|
Find the neighbours of a given node in the graph and return their descriptions.
|
||||||
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.
|
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.
|
||||||
|
|
@ -12,13 +12,12 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructur
|
||||||
Parameters:
|
Parameters:
|
||||||
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
|
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
|
||||||
- query (str): Unused in this implementation but could be used for future enhancements.
|
- query (str): Unused in this implementation but could be used for future enhancements.
|
||||||
- infrastructure_config (Dict): Configuration that includes the graph engine type.
|
|
||||||
- other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node.
|
- other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node.
|
- Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node.
|
||||||
"""
|
"""
|
||||||
node_id = other_param.get('node_id') if other_param else None
|
node_id = other_param.get('node_id') if other_param else query
|
||||||
|
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from cognee.infrastructure.databases.graph.get_graph_client import get_graph_cli
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from cognee.shared.data_models import GraphDBType
|
from cognee.shared.data_models import GraphDBType
|
||||||
|
|
||||||
async def search_neighbour(graph: Union[nx.Graph, any], node_id: str,
|
async def search_neighbour(graph: Union[nx.Graph, any], query: str,
|
||||||
other_param: dict = None):
|
other_param: dict = None):
|
||||||
"""
|
"""
|
||||||
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.
|
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.
|
||||||
|
|
@ -23,8 +23,7 @@ async def search_neighbour(graph: Union[nx.Graph, any], node_id: str,
|
||||||
- List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node.
|
- List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node.
|
||||||
"""
|
"""
|
||||||
from cognee.infrastructure import infrastructure_config
|
from cognee.infrastructure import infrastructure_config
|
||||||
if node_id is None:
|
node_id = other_param.get('node_id') if other_param else query
|
||||||
node_id = other_param.get('node_id') if other_param else None
|
|
||||||
|
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -1,41 +1,84 @@
|
||||||
from cognee.config import Config
|
from cognee.config import Config
|
||||||
from cognee.infrastructure import infrastructure_config
|
from cognee.infrastructure import infrastructure_config
|
||||||
|
from cognee.infrastructure.llm import llm_config
|
||||||
config = Config()
|
|
||||||
config.load()
|
|
||||||
|
|
||||||
def get_settings():
|
def get_settings():
|
||||||
vector_engine_choice = infrastructure_config.get_config()["vector_engine_choice"]
|
config = Config()
|
||||||
vector_db_options = [{
|
config.load()
|
||||||
"value": "weaviate",
|
|
||||||
"label": "Weaviate",
|
vector_dbs = [{
|
||||||
|
"value": "weaviate",
|
||||||
|
"label": "Weaviate",
|
||||||
}, {
|
}, {
|
||||||
"value": "qdrant",
|
"value": "qdrant",
|
||||||
"label": "Qdrant",
|
"label": "Qdrant",
|
||||||
}, {
|
}, {
|
||||||
"value": "lancedb",
|
"value": "lancedb",
|
||||||
"label": "LanceDB",
|
"label": "LanceDB",
|
||||||
}]
|
}]
|
||||||
|
|
||||||
vector_db_config = dict(
|
vector_engine = infrastructure_config.get_config("vector_engine")
|
||||||
url = config.weaviate_url,
|
|
||||||
apiKey = config.weaviate_api_key,
|
llm_providers = [{
|
||||||
choice = vector_db_options[0],
|
"value": "openai",
|
||||||
options = vector_db_options,
|
"label": "OpenAI",
|
||||||
) if vector_engine_choice == "weaviate" else dict(
|
}, {
|
||||||
url = config.qdrant_url,
|
"value": "ollama",
|
||||||
apiKey = config.qdrant_api_key,
|
"label": "Ollama",
|
||||||
choice = vector_db_options[1],
|
}, {
|
||||||
options = vector_db_options,
|
"value": "anthropic",
|
||||||
) if vector_engine_choice == "qdrant" else dict(
|
"label": "Anthropic",
|
||||||
url = infrastructure_config.get_config("lance_db_path"),
|
}]
|
||||||
choice = vector_db_options[2],
|
|
||||||
options = vector_db_options,
|
|
||||||
)
|
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
llm = dict(
|
llm = {
|
||||||
openAIApiKey = config.openai_key[:-10] + "**********",
|
"provider": {
|
||||||
),
|
"label": llm_config.llm_provider,
|
||||||
vectorDB = vector_db_config,
|
"value": llm_config.llm_provider,
|
||||||
|
} if llm_config.llm_provider else llm_providers[0],
|
||||||
|
"model": {
|
||||||
|
"value": llm_config.llm_model,
|
||||||
|
"label": llm_config.llm_model,
|
||||||
|
} if llm_config.llm_model else None,
|
||||||
|
"apiKey": llm_config.llm_api_key[:-10] + "**********" if llm_config.llm_api_key else None,
|
||||||
|
"providers": llm_providers,
|
||||||
|
"models": {
|
||||||
|
"openai": [{
|
||||||
|
"value": "gpt-4o",
|
||||||
|
"label": "gpt-4o",
|
||||||
|
}, {
|
||||||
|
"value": "gpt-4-turbo",
|
||||||
|
"label": "gpt-4-turbo",
|
||||||
|
}, {
|
||||||
|
"value": "gpt-3.5-turbo",
|
||||||
|
"label": "gpt-3.5-turbo",
|
||||||
|
}],
|
||||||
|
"ollama": [{
|
||||||
|
"value": "llama3",
|
||||||
|
"label": "llama3",
|
||||||
|
}, {
|
||||||
|
"value": "mistral",
|
||||||
|
"label": "mistral",
|
||||||
|
}],
|
||||||
|
"anthropic": [{
|
||||||
|
"value": "Claude 3 Opus",
|
||||||
|
"label": "Claude 3 Opus",
|
||||||
|
}, {
|
||||||
|
"value": "Claude 3 Sonnet",
|
||||||
|
"label": "Claude 3 Sonnet",
|
||||||
|
}, {
|
||||||
|
"value": "Claude 3 Haiku",
|
||||||
|
"label": "Claude 3 Haiku",
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
},
|
||||||
|
vectorDB = {
|
||||||
|
"provider": {
|
||||||
|
"label": vector_engine.name,
|
||||||
|
"value": vector_engine.name.lower(),
|
||||||
|
},
|
||||||
|
"url": vector_engine.url,
|
||||||
|
"apiKey": vector_engine.api_key,
|
||||||
|
"options": vector_dbs,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,20 @@
|
||||||
import os
|
import json
|
||||||
|
import logging
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from cognee.config import Config
|
from cognee.infrastructure.llm import llm_config
|
||||||
|
from cognee.infrastructure import infrastructure_config
|
||||||
config = Config()
|
|
||||||
|
|
||||||
class LLMConfig(BaseModel):
|
class LLMConfig(BaseModel):
|
||||||
openAIApiKey: str
|
apiKey: str
|
||||||
|
model: str
|
||||||
|
provider: str
|
||||||
|
|
||||||
async def save_llm_config(llm_config: LLMConfig):
|
async def save_llm_config(new_llm_config: LLMConfig):
|
||||||
if "*" in llm_config.openAIApiKey:
|
llm_config.llm_provider = new_llm_config.provider
|
||||||
return
|
llm_config.llm_model = new_llm_config.model
|
||||||
|
|
||||||
os.environ["OPENAI_API_KEY"] = llm_config.openAIApiKey
|
if "*****" not in new_llm_config.apiKey and len(new_llm_config.apiKey.strip()) > 0:
|
||||||
config.load()
|
llm_config.llm_api_key = new_llm_config.apiKey
|
||||||
|
|
||||||
|
logging.error(json.dumps(llm_config.to_dict()))
|
||||||
|
infrastructure_config.llm_engine = None
|
||||||
|
|
|
||||||
|
|
@ -7,24 +7,24 @@ from cognee.infrastructure import infrastructure_config
|
||||||
config = Config()
|
config = Config()
|
||||||
|
|
||||||
class VectorDBConfig(BaseModel):
|
class VectorDBConfig(BaseModel):
|
||||||
choice: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
|
|
||||||
url: str
|
url: str
|
||||||
apiKey: str
|
apiKey: str
|
||||||
|
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
|
||||||
|
|
||||||
async def save_vector_db_config(vector_db_config: VectorDBConfig):
|
async def save_vector_db_config(vector_db_config: VectorDBConfig):
|
||||||
if vector_db_config.choice == "weaviate":
|
if vector_db_config.provider == "weaviate":
|
||||||
os.environ["WEAVIATE_URL"] = vector_db_config.url
|
os.environ["WEAVIATE_URL"] = vector_db_config.url
|
||||||
os.environ["WEAVIATE_API_KEY"] = vector_db_config.apiKey
|
os.environ["WEAVIATE_API_KEY"] = vector_db_config.apiKey
|
||||||
|
|
||||||
remove_qdrant_config()
|
remove_qdrant_config()
|
||||||
|
|
||||||
if vector_db_config.choice == "qdrant":
|
if vector_db_config.provider == "qdrant":
|
||||||
os.environ["QDRANT_URL"] = vector_db_config.url
|
os.environ["QDRANT_URL"] = vector_db_config.url
|
||||||
os.environ["QDRANT_API_KEY"] = vector_db_config.apiKey
|
os.environ["QDRANT_API_KEY"] = vector_db_config.apiKey
|
||||||
|
|
||||||
remove_weaviate_config()
|
remove_weaviate_config()
|
||||||
|
|
||||||
if vector_db_config.choice == "lancedb":
|
if vector_db_config.provider == "lancedb":
|
||||||
remove_qdrant_config()
|
remove_qdrant_config()
|
||||||
remove_weaviate_config()
|
remove_weaviate_config()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,7 @@ services:
|
||||||
limits:
|
limits:
|
||||||
cpus: "4.0"
|
cpus: "4.0"
|
||||||
memory: 8GB
|
memory: 8GB
|
||||||
|
|
||||||
frontend:
|
frontend:
|
||||||
container_name: frontend
|
container_name: frontend
|
||||||
build:
|
build:
|
||||||
|
|
@ -55,7 +56,6 @@ services:
|
||||||
networks:
|
networks:
|
||||||
- cognee_backend
|
- cognee_backend
|
||||||
|
|
||||||
|
|
||||||
postgres:
|
postgres:
|
||||||
image: postgres
|
image: postgres
|
||||||
container_name: postgres
|
container_name: postgres
|
||||||
|
|
@ -69,18 +69,7 @@ services:
|
||||||
- cognee_backend
|
- cognee_backend
|
||||||
ports:
|
ports:
|
||||||
- "5432:5432"
|
- "5432:5432"
|
||||||
litellm:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
args:
|
|
||||||
target: runtime
|
|
||||||
image: ghcr.io/berriai/litellm:main-latest
|
|
||||||
ports:
|
|
||||||
- "4000:4000" # Map the container port to the host, change the host port if necessary
|
|
||||||
volumes:
|
|
||||||
- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
|
|
||||||
# You can change the port or number of workers as per your requirements or pass any new supported CLI augument. Make sure the port passed here matches with the container port defined above in `ports` value
|
|
||||||
command: [ "--config", "/app/config.yaml", "--port", "4000", "--num_workers", "8" ]
|
|
||||||
falkordb:
|
falkordb:
|
||||||
image: falkordb/falkordb:edge
|
image: falkordb/falkordb:edge
|
||||||
container_name: falkordb
|
container_name: falkordb
|
||||||
|
|
|
||||||
8521
poetry.lock
generated
Normal file
8521
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -60,7 +60,6 @@ tiktoken = "^0.6.0"
|
||||||
dspy-ai = "2.4.3"
|
dspy-ai = "2.4.3"
|
||||||
posthog = "^3.5.0"
|
posthog = "^3.5.0"
|
||||||
lancedb = "^0.6.10"
|
lancedb = "^0.6.10"
|
||||||
|
|
||||||
importlib-metadata = "6.8.0"
|
importlib-metadata = "6.8.0"
|
||||||
litellm = "^1.37.3"
|
litellm = "^1.37.3"
|
||||||
groq = "^0.5.0"
|
groq = "^0.5.0"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue