feat: add llm config

Boris Arzentar 2024-05-22 22:36:30 +02:00
parent 9bb30bc43a
commit 84c0c8cab5
30 changed files with 9034 additions and 214 deletions

.gitignore
View file

@@ -169,3 +169,4 @@ cognee/cache/
# Default cognee system directory, used in development
.cognee_system/
+.data_storage/

View file

@@ -228,6 +228,126 @@
        "node": ">= 10"
      }
    },
"node_modules/@next/swc-darwin-x64": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.3.tgz",
"integrity": "sha512-6adp7waE6P1TYFSXpY366xwsOnEXM+y1kgRpjSRVI2CBDOcbRjsJ67Z6EgKIqWIue52d2q/Mx8g9MszARj8IEA==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-linux-arm64-gnu": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.3.tgz",
"integrity": "sha512-cuzCE/1G0ZSnTAHJPUT1rPgQx1w5tzSX7POXSLaS7w2nIUJUD+e25QoXD/hMfxbsT9rslEXugWypJMILBj/QsA==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-linux-arm64-musl": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.3.tgz",
"integrity": "sha512-0D4/oMM2Y9Ta3nGuCcQN8jjJjmDPYpHX9OJzqk42NZGJocU2MqhBq5tWkJrUQOQY9N+In9xOdymzapM09GeiZw==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-linux-x64-gnu": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.3.tgz",
"integrity": "sha512-ENPiNnBNDInBLyUU5ii8PMQh+4XLr4pG51tOp6aJ9xqFQ2iRI6IH0Ds2yJkAzNV1CfyagcyzPfROMViS2wOZ9w==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-linux-x64-musl": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.3.tgz",
"integrity": "sha512-BTAbq0LnCbF5MtoM7I/9UeUu/8ZBY0i8SFjUMCbPDOLv+un67e2JgyN4pmgfXBwy/I+RHu8q+k+MCkDN6P9ViQ==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-win32-arm64-msvc": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.3.tgz",
"integrity": "sha512-AEHIw/dhAMLNFJFJIJIyOFDzrzI5bAjI9J26gbO5xhAKHYTZ9Or04BesFPXiAYXDNdrwTP2dQceYA4dL1geu8A==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-win32-ia32-msvc": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.3.tgz",
"integrity": "sha512-vga40n1q6aYb0CLrM+eEmisfKCR45ixQYXuBXxOOmmoV8sYST9k7E3US32FsY+CkkF7NtzdcebiFT4CHuMSyZw==",
"cpu": [
"ia32"
],
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-win32-x64-msvc": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.3.tgz",
"integrity": "sha512-Q1/zm43RWynxrO7lW4ehciQVj+5ePBhOK+/K2P7pLFX3JaJ/IZVC69SHidrmZSOkqz7ECIOhhy7XhAFG4JYyHA==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@nodelib/fs.scandir": { "node_modules/@nodelib/fs.scandir": {
"version": "2.1.5", "version": "2.1.5",
"dev": true, "dev": true,

View file

@@ -73,6 +73,8 @@ function useDatasets() {
      if (datasets.length > 0) {
        checkDatasetStatuses(datasets);
+     } else {
+       window.location.href = '/wizard';
      }
    });
  }, [checkDatasetStatuses]);

View file

@@ -15,18 +15,24 @@
  flex: 1;
  padding: 16px;
  border-top: 2px solid white;
+ overflow: hidden;
+}
+.messagesContainer {
+  flex: 1;
+  overflow-y: auto;
}
.messages {
  flex: 1;
  padding-top: 24px;
  padding-bottom: 24px;
- overflow-y: auto;
}
.message {
  padding: 16px;
  border-radius: var(--border-radius);
+ width: max-content;
}
.userMessage {

View file

@@ -1,4 +1,4 @@
-import { CTAButton, CloseIcon, GhostButton, Input, Spacer, Stack, Text } from 'ohmy-ui';
+import { CTAButton, CloseIcon, GhostButton, Input, Spacer, Stack, Text, DropdownSelect } from 'ohmy-ui';
import styles from './SearchView.module.css';
import { useCallback, useState } from 'react';
import { v4 } from 'uuid';
@@ -22,10 +22,27 @@ export default function SearchView({ onClose }: SearchViewProps) {
    setInputValue(event.target.value);
  }, []);

+  const searchOptions = [{
+    value: 'SIMILARITY',
+    label: 'Similarity',
+  }, {
+    value: 'NEIGHBOR',
+    label: 'Neighbor',
+  }, {
+    value: 'SUMMARY',
+    label: 'Summary',
+  }, {
+    value: 'ADJACENT',
+    label: 'Adjacent',
+  }, {
+    value: 'CATEGORIES',
+    label: 'Categories',
+  }];
+
+  const [searchType, setSearchType] = useState(searchOptions[0]);
+
  const handleSearchSubmit = useCallback((event: React.FormEvent<HTMLFormElement>) => {
    event.preventDefault();

    setMessages((currentMessages) => [
      ...currentMessages,
      {
@@ -43,6 +60,7 @@ export default function SearchView({ onClose }: SearchViewProps) {
      body: JSON.stringify({
        query_params: {
          query: inputValue,
+         searchType: searchType.value,
        },
      }),
    })
@@ -58,8 +76,8 @@ export default function SearchView({ onClose }: SearchViewProps) {
      ]);

      setInputValue('');
    })
-  }, [inputValue]);
+  }, [inputValue, searchType]);

  return (
    <Stack className={styles.searchViewContainer}>
      <Stack gap="between" align="center/" orientation="horizontal">
@@ -71,20 +89,27 @@ export default function SearchView({ onClose }: SearchViewProps) {
        </GhostButton>
      </Stack>
      <Stack className={styles.searchContainer}>
-       <Stack gap="2" className={styles.messages} align="end">
-         {messages.map((message) => (
-           <Text
-             key={message.id}
-             className={classNames(styles.message, {
-               [styles.userMessage]: message.user === "user",
-             })}
-           >
-             {message.text}
-           </Text>
-         ))}
-       </Stack>
+       <div className={styles.messagesContainer}>
+         <Stack gap="2" className={styles.messages} align="end">
+           {messages.map((message) => (
+             <Text
+               key={message.id}
+               className={classNames(styles.message, {
+                 [styles.userMessage]: message.user === "user",
+               })}
+             >
+               {message.text}
+             </Text>
+           ))}
+         </Stack>
+       </div>
        <form onSubmit={handleSearchSubmit}>
          <Stack orientation="horizontal" gap="2">
+           <DropdownSelect
+             value={searchType}
+             options={searchOptions}
+             onChange={setSearchType}
+           />
            <Input value={inputValue} onChange={handleInputChange} name="searchInput" placeholder="Search" />
            <CTAButton type="submit">Search</CTAButton>
          </Stack>

View file

@@ -7,12 +7,22 @@ interface SelectOption {
}

export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
-  const [llmConfig, setLLMConfig] = useState<{ openAIApiKey: string }>();
+  const [llmConfig, setLLMConfig] = useState<{
+    apiKey: string;
+    model: SelectOption;
+    models: {
+      openai: SelectOption[];
+      ollama: SelectOption[];
+      anthropic: SelectOption[];
+    };
+    provider: SelectOption;
+    providers: SelectOption[];
+  }>();
  const [vectorDBConfig, setVectorDBConfig] = useState<{
-    choice: SelectOption;
-    options: SelectOption[];
    url: string;
    apiKey: string;
+    provider: SelectOption;
+    options: SelectOption[];
  }>();

  const {
@@ -23,10 +33,18 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
  const saveConfig = (event: React.FormEvent<HTMLFormElement>) => {
    event.preventDefault();

-    const newOpenAIApiKey = event.target.openAIApiKey.value;
-    const newVectorDBChoice = vectorDBConfig?.choice.value;
-    const newVectorDBUrl = event.target.vectorDBUrl.value;
-    const newVectorDBApiKey = event.target.vectorDBApiKey.value;
+    const newVectorConfig = {
+      provider: vectorDBConfig?.provider.value,
+      url: event.target.vectorDBUrl.value,
+      apiKey: event.target.vectorDBApiKey.value,
+    };
+    const newLLMConfig = {
+      provider: llmConfig?.provider.value,
+      model: llmConfig?.model.value,
+      apiKey: event.target.llmApiKey.value,
+    };

    startSaving();
@@ -36,14 +54,8 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({
-        llm: {
-          openAIApiKey: newOpenAIApiKey,
-        },
-        vectorDB: {
-          choice: newVectorDBChoice,
-          url: newVectorDBUrl,
-          apiKey: newVectorDBApiKey,
-        },
+        llm: newLLMConfig,
+        vectorDB: newVectorConfig,
      }),
    })
    .then(() => {
@@ -52,12 +64,12 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
    .finally(() => stopSaving());
  };

-  const handleVectorDBChange = useCallback((newChoice: SelectOption) => {
+  const handleVectorDBChange = useCallback((newVectorDBProvider: SelectOption) => {
    setVectorDBConfig((config) => {
-      if (config?.choice !== newChoice) {
+      if (config?.provider !== newVectorDBProvider) {
        return {
          ...config,
-          choice: newChoice,
+          provider: newVectorDBProvider,
          url: '',
          apiKey: '',
        };
@@ -66,11 +78,40 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
      });
  }, []);

+  const handleLLMProviderChange = useCallback((newLLMProvider: SelectOption) => {
+    setLLMConfig((config) => {
+      if (config?.provider !== newLLMProvider) {
+        return {
+          ...config,
+          provider: newLLMProvider,
+          model: config?.models[newLLMProvider.value][0],
+          apiKey: '',
+        };
+      }
+      return config;
+    });
+  }, []);

+  const handleLLMModelChange = useCallback((newLLMModel: SelectOption) => {
+    setLLMConfig((config) => {
+      if (config?.model !== newLLMModel) {
+        return {
+          ...config,
+          model: newLLMModel,
+        };
+      }
+      return config;
+    });
+  }, []);

  useEffect(() => {
    const fetchVectorDBChoices = async () => {
      const response = await fetch('http://0.0.0.0:8000/settings');
      const settings = await response.json();

+      if (!settings.llm.model) {
+        settings.llm.model = settings.llm.models[settings.llm.provider.value][0];
+      }

      setLLMConfig(settings.llm);
      setVectorDBConfig(settings.vectorDB);
    };
@@ -79,40 +120,54 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
  return (
    <Modal isOpen={isOpen} onClose={onClose}>
-      <Stack gap="4" orientation="vertical" align="center/">
+      <Stack gap="8" orientation="vertical" align="center/">
        <H2>Settings</H2>
        <form onSubmit={saveConfig} style={{ width: '100%' }}>
-          <Stack gap="2" orientation="vertical">
-            <H3>LLM Config</H3>
-            <FormGroup orientation="vertical" align="center/" gap="1">
-              <FormLabel>OpenAI API Key</FormLabel>
-              <FormInput>
-                <Input defaultValue={llmConfig?.openAIApiKey} name="openAIApiKey" placeholder="OpenAI API Key" />
-              </FormInput>
-            </FormGroup>
-            <H3>Vector Database Config</H3>
-            <DropdownSelect
-              value={vectorDBConfig?.choice}
-              options={vectorDBConfig?.options}
-              onChange={handleVectorDBChange}
-            />
-            <FormGroup orientation="vertical" align="center/" gap="1">
-              <FormLabel>Vector DB url</FormLabel>
-              <FormInput>
-                <Input defaultValue={vectorDBConfig?.url} name="vectorDBUrl" placeholder="Vector DB API url" />
-              </FormInput>
-            </FormGroup>
-            <FormGroup orientation="vertical" align="center/" gap="1">
-              <FormLabel>Vector DB API key</FormLabel>
-              <FormInput>
-                <Input defaultValue={vectorDBConfig?.apiKey} name="vectorDBApiKey" placeholder="Vector DB API key" />
-              </FormInput>
-            </FormGroup>
-            <Stack align="/end">
-              <Spacer top="2">
-                <CTAButton type="submit">Save</CTAButton>
-              </Spacer>
-            </Stack>
-          </Stack>
+          <Stack gap="4" orientation="vertical">
+            <Stack gap="2" orientation="vertical">
+              <H3>LLM Config</H3>
+              <FormGroup orientation="horizontal" align="center/" gap="4">
+                <FormLabel>LLM provider:</FormLabel>
+                <DropdownSelect
+                  value={llmConfig?.provider}
+                  options={llmConfig?.providers}
+                  onChange={handleLLMProviderChange}
+                />
+              </FormGroup>
+              <FormGroup orientation="horizontal" align="center/" gap="4">
+                <FormLabel>LLM model:</FormLabel>
+                <DropdownSelect
+                  value={llmConfig?.model}
+                  options={llmConfig?.provider ? llmConfig?.models[llmConfig?.provider.value] : []}
+                  onChange={handleLLMModelChange}
+                />
+              </FormGroup>
+              <FormInput>
+                <Input defaultValue={llmConfig?.apiKey} name="llmApiKey" placeholder="LLM API key" />
+              </FormInput>
+            </Stack>
+            <Stack gap="2" orientation="vertical">
+              <H3>Vector Database Config</H3>
+              <FormGroup orientation="horizontal" align="center/" gap="4">
+                <FormLabel>Vector DB provider:</FormLabel>
+                <DropdownSelect
+                  value={vectorDBConfig?.provider}
+                  options={vectorDBConfig?.options}
+                  onChange={handleVectorDBChange}
+                />
+              </FormGroup>
+              <FormInput>
+                <Input defaultValue={vectorDBConfig?.url} name="vectorDBUrl" placeholder="Vector DB instance url" />
+              </FormInput>
+              <FormInput>
+                <Input defaultValue={vectorDBConfig?.apiKey} name="vectorDBApiKey" placeholder="Vector DB API key" />
+              </FormInput>
+              <Stack align="/end">
+                <Spacer top="2">
+                  <CTAButton type="submit">Save</CTAButton>
+                </Spacer>
+              </Stack>
+            </Stack>
+          </Stack>
        </form>

View file

@@ -3,6 +3,7 @@ import os
import aiohttp
import uvicorn
+import json
import logging
# Set up logging # Set up logging
@@ -16,7 +17,7 @@ from cognee.config import Config
config = Config()
config.load()

-from typing import Dict, Any, List, Union, Annotated, Literal
+from typing import Dict, Any, List, Union, Annotated, Literal, Optional
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
@@ -25,6 +26,7 @@ from pydantic import BaseModel
app = FastAPI(debug=True)

origins = [
+    "http://frontend:3000",
    "http://localhost:3000",
    "http://localhost:3001",
]
@@ -220,8 +222,16 @@ async def search(payload: SearchPayload):
    from cognee import search as cognee_search

    try:
-        search_type = "SIMILARITY"
-        await cognee_search(search_type, payload.query_params)
+        search_type = payload.query_params["searchType"]
+        params = {
+            "query": payload.query_params["query"],
+        }
+
+        results = await cognee_search(search_type, params)
+
+        return JSONResponse(
+            status_code = 200,
+            content = json.dumps(results)
+        )
    except Exception as error:
        return JSONResponse(
            status_code = 409,
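
For reference, a minimal client-side sketch of the reworked search endpoint (assuming the handler above is exposed as POST /search on the same host the frontend already calls for /settings, and that the requests package is installed; the response shape depends on the configured search handler):

    import requests

    # searchType mirrors the dropdown options added to SearchView:
    # SIMILARITY, NEIGHBOR, SUMMARY, ADJACENT, CATEGORIES.
    response = requests.post(
        "http://0.0.0.0:8000/search",
        json = {
            "query_params": {
                "query": "What do my documents say about LLM providers?",
                "searchType": "SIMILARITY",
            },
        },
    )

    print(response.status_code)
    print(response.text)
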
@@ -236,25 +246,26 @@ async def get_settings():
class LLMConfig(BaseModel):
-    openAIApiKey: str
+    provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
+    model: str
+    apiKey: str

class VectorDBConfig(BaseModel):
-    choice: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
+    provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
    url: str
    apiKey: str

class SettingsPayload(BaseModel):
-    llm: LLMConfig | None = None
-    vectorDB: VectorDBConfig | None = None
+    llm: Optional[LLMConfig] = None
+    vectorDB: Optional[VectorDBConfig] = None

@app.post("/settings", response_model=dict)
async def save_config(new_settings: SettingsPayload):
    from cognee.modules.settings import save_llm_config, save_vector_db_config

-    if hasattr(new_settings, "llm"):
+    if new_settings.llm is not None:
        await save_llm_config(new_settings.llm)

-    if hasattr(new_settings, "vectorDB"):
+    if new_settings.vectorDB is not None:
        await save_vector_db_config(new_settings.vectorDB)

    return JSONResponse(
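
A matching sketch for saving settings through the models above (field names come straight from SettingsPayload, LLMConfig and VectorDBConfig; the host and port are the ones the frontend already uses, everything else is placeholder data):

    import requests

    new_settings = {
        "llm": {
            "provider": "openai",   # one of: openai, ollama, anthropic
            "model": "gpt-4o",
            "apiKey": "sk-placeholder",
        },
        "vectorDB": {
            "provider": "qdrant",   # one of: lancedb, qdrant, weaviate
            "url": "http://localhost:6333",
            "apiKey": "",
        },
    }

    # Either section may be omitted; save_config only persists the parts that are present.
    response = requests.post("http://0.0.0.0:8000/settings", json = new_settings)
    print(response.status_code)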

View file

@@ -10,11 +10,13 @@ from dotenv import load_dotenv
from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import ChunkStrategy, DefaultGraphModel

-base_dir = Path(__file__).resolve().parent.parent
-# Load the .env file from the base directory
-dotenv_path = base_dir / ".env"
-load_dotenv(dotenv_path=dotenv_path)
+def load_dontenv():
+    base_dir = Path(__file__).resolve().parent.parent
+    # Load the .env file from the base directory
+    dotenv_path = base_dir / ".env"
+    load_dotenv(dotenv_path=dotenv_path, override = True)
+
+load_dontenv()

@dataclass
class Config:
@@ -50,16 +52,20 @@ class Config:
    graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")

    # Model parameters
-    llm_provider: str = os.getenv("LLM_PROVIDER","openai") #openai, or custom or ollama
-    custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
-    custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
-    custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
-    ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1"
-    ollama_key: Optional[str] = "ollama"
-    ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct"
-    openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o"
-    model_endpoint: str = "openai"
-    openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
+    llm_provider: str = os.getenv("LLM_PROVIDER", "openai") #openai, or custom or ollama
+    llm_model: str = os.getenv("LLM_MODEL", None)
+    llm_api_key: str = os.getenv("LLM_API_KEY", None)
+    llm_endpoint: str = os.getenv("LLM_ENDPOINT", None)
+
+    # custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
+    # custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
+    # custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
+    # ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1"
+    # ollama_key: Optional[str] = "ollama"
+    # ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct"
+    # openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o"
+    # model_endpoint: str = "openai"
+    # llm_api_key: Optional[str] = os.getenv("OPENAI_API_KEY")
    openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
    openai_embedding_model = "text-embedding-3-large"
    openai_embedding_dimensions = 3072
@@ -132,6 +138,7 @@ class Config:
    def load(self):
        """Loads the configuration from a file or environment variables."""
+        load_dontenv()

        config = configparser.ConfigParser()
        config.read(self.config_path)
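
For illustration, a sketch of the consolidated environment variables this hunk introduces (names from the diff, values are placeholders). Because the dataclass defaults are resolved with os.getenv when cognee.config is imported, the variables need to be in the environment, or in the .env file that load_dontenv() reads, before the import happens:

    import os

    # Generic LLM settings that replace the CUSTOM_* / OLLAMA_* / OPENAI_* variables.
    os.environ["LLM_PROVIDER"] = "ollama"
    os.environ["LLM_MODEL"] = "mistral:instruct"
    os.environ["LLM_ENDPOINT"] = "http://localhost:11434/v1"
    os.environ["LLM_API_KEY"] = "ollama"

    from cognee.config import Config

    config = Config()
    config.load()
    print(config.llm_provider, config.llm_model)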

View file

@@ -3,7 +3,7 @@ from .databases.relational import DuckDBAdapter, DatabaseEngine
from .databases.vector.vector_db_interface import VectorDBInterface
from .databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbeddingEngine
from .llm.llm_interface import LLMInterface
-from .llm.openai.adapter import OpenAIAdapter
+from .llm.get_llm_client import get_llm_client
from .files.storage import LocalStorage
from .data.chunking.DefaultChunkEngine import DefaultChunkEngine
from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \
@@ -35,6 +35,10 @@ class InfrastructureConfig():
    chunk_engine = None
    graph_topology = config.graph_topology
    monitoring_tool = config.monitoring_tool
+    llm_provider: str = None
+    llm_model: str = None
+    llm_endpoint: str = None
+    llm_api_key: str = None

    def get_config(self, config_entity: str = None) -> dict:
        if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
@@ -84,7 +88,8 @@ class InfrastructureConfig():
            self.graph_topology = config.graph_topology

        if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
-            self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
+            self.llm_engine = get_llm_client()

        if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
            self.database_directory_path = self.system_root_directory + "/" + config.db_path
@@ -115,8 +120,8 @@ class InfrastructureConfig():
            from .databases.vector.qdrant.QDrantAdapter import QDrantAdapter

            self.vector_engine = QDrantAdapter(
-                qdrant_url = config.qdrant_url,
-                qdrant_api_key = config.qdrant_api_key,
+                url = config.qdrant_url,
+                api_key = config.qdrant_api_key,
                embedding_engine = self.embedding_engine
            )
            self.vector_engine_choice = "qdrant"
@@ -127,11 +132,10 @@ class InfrastructureConfig():
            LocalStorage.ensure_directory_exists(lance_db_path)

            self.vector_engine = LanceDBAdapter(
-                uri = lance_db_path,
+                url = lance_db_path,
                api_key = None,
                embedding_engine = self.embedding_engine,
            )
-            self.lance_db_path = lance_db_path
            self.vector_engine_choice = "lancedb"

        if config_entity is not None:

View file

@@ -1,6 +1,5 @@
from typing import List, Optional, get_type_hints, Generic, TypeVar
import asyncio
-from pydantic import BaseModel, Field
import lancedb
from lancedb.pydantic import Vector, LanceModel
from cognee.infrastructure.files.storage import LocalStorage
@@ -9,21 +8,25 @@ from ..vector_db_interface import VectorDBInterface, DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine

class LanceDBAdapter(VectorDBInterface):
+    name = "LanceDB"
+    url: str
+    api_key: str
    connection: lancedb.AsyncConnection = None

    def __init__(
        self,
-        uri: Optional[str],
+        url: Optional[str],
        api_key: Optional[str],
        embedding_engine: EmbeddingEngine,
    ):
-        self.uri = uri
+        self.url = url
        self.api_key = api_key
        self.embedding_engine = embedding_engine

    async def get_connection(self):
        if self.connection is None:
-            self.connection = await lancedb.connect_async(self.uri, api_key = self.api_key)
+            self.connection = await lancedb.connect_async(self.url, api_key = self.api_key)

        return self.connection
@@ -35,12 +38,12 @@ class LanceDBAdapter(VectorDBInterface):
        collection_names = await connection.table_names()
        return collection_name in collection_names

-    async def create_collection(self, collection_name: str, payload_schema: BaseModel):
+    async def create_collection(self, collection_name: str, payload_schema = None):
        data_point_types = get_type_hints(DataPoint)
        vector_size = self.embedding_engine.get_vector_size()

        class LanceDataPoint(LanceModel):
-            id: data_point_types["id"] = Field(...)
+            id: data_point_types["id"]
            vector: Vector(vector_size)
            payload: payload_schema
@@ -128,7 +131,7 @@ class LanceDBAdapter(VectorDBInterface):
        collection_name: str,
        query_texts: List[str],
        limit: int = None,
-        with_vector: bool = False,
+        with_vectors: bool = False,
    ):
        query_vectors = await self.embedding_engine.embed_text(query_texts)
@@ -137,11 +140,11 @@ class LanceDBAdapter(VectorDBInterface):
                collection_name = collection_name,
                query_vector = query_vector,
                limit = limit,
-                with_vector = with_vector,
+                with_vector = with_vectors,
            ) for query_vector in query_vectors]
        )

    async def prune(self):
        # Clean up the database if it was set up as temporary
-        if self.uri.startswith("/"):
-            LocalStorage.remove_all(self.uri) # Remove the temporary directory and files inside
+        if self.url.startswith("/"):
+            LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside

View file

@@ -26,29 +26,29 @@ def create_quantization_config(quantization_config: Dict):
    return None

class QDrantAdapter(VectorDBInterface):
-    qdrant_url: str = None
+    name = "Qdrant"
+    url: str = None
+    api_key: str = None
    qdrant_path: str = None
-    qdrant_api_key: str = None

-    def __init__(self, qdrant_url, qdrant_api_key, embedding_engine: EmbeddingEngine, qdrant_path = None):
+    def __init__(self, url, api_key, embedding_engine: EmbeddingEngine, qdrant_path = None):
        self.embedding_engine = embedding_engine

        if qdrant_path is not None:
            self.qdrant_path = qdrant_path
        else:
-            self.qdrant_url = qdrant_url
-            self.qdrant_api_key = qdrant_api_key
+            self.url = url
+            self.api_key = api_key

    def get_qdrant_client(self) -> AsyncQdrantClient:
        if self.qdrant_path is not None:
            return AsyncQdrantClient(
                path = self.qdrant_path, port=6333
            )
-        elif self.qdrant_url is not None:
+        elif self.url is not None:
            return AsyncQdrantClient(
-                url = self.qdrant_url,
-                api_key = self.qdrant_api_key,
+                url = self.url,
+                api_key = self.api_key,
                port = 6333
            )

View file

@@ -1,7 +1,6 @@
import asyncio
from uuid import UUID
from typing import List, Optional
-from multiprocessing import Pool
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult

@@ -9,19 +8,24 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine
class WeaviateAdapter(VectorDBInterface):
-    async_pool: Pool = None
+    name = "Weaviate"
+    url: str
+    api_key: str
    embedding_engine: EmbeddingEngine = None

    def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine):
        import weaviate
        import weaviate.classes as wvc

+        self.url = url
+        self.api_key = api_key
        self.embedding_engine = embedding_engine

        self.client = weaviate.connect_to_wcs(
-            cluster_url=url,
-            auth_credentials=weaviate.auth.AuthApiKey(api_key),
-            additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30))
+            cluster_url = url,
+            auth_credentials = weaviate.auth.AuthApiKey(api_key),
+            additional_config = wvc.init.AdditionalConfig(timeout = wvc.init.Timeout(init=30))
        )

    async def embed_data(self, data: List[str]) -> List[float]:

View file

@@ -0,0 +1 @@
from .config import llm_config

View file

@@ -8,7 +8,9 @@ from cognee.infrastructure.llm.prompts import read_query_prompt
class AnthropicAdapter(LLMInterface):
-    """Adapter for Ollama's API"""
+    """Adapter for Anthropic API"""
+    name = "Anthropic"
+    model: str

    def __init__(self, model: str = None):
        self.aclient = instructor.patch(

View file

@@ -0,0 +1,16 @@
class LLMConfig():
    llm_provider: str = None
    llm_model: str = None
    llm_endpoint: str = None
    llm_api_key: str = None

    def to_dict(self) -> dict:
        return {
            "provider": self.llm_provider,
            "model": self.llm_model,
            "endpoint": self.llm_endpoint,
            "apiKey": self.llm_api_key,
        }

llm_config = LLMConfig()
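
This module-level llm_config instance becomes the in-process source of truth: save_llm_config() mutates it and get_llm_client() reads it. A minimal sketch of that flow, with placeholder values:

    from cognee.infrastructure.llm import llm_config

    # What save_llm_config() effectively does when the frontend posts new settings.
    llm_config.llm_provider = "openai"
    llm_config.llm_model = "gpt-4o"
    llm_config.llm_api_key = "sk-placeholder"

    # Serialized shape used for logging and mirrored by the /settings response.
    print(llm_config.to_dict())
    # {'provider': 'openai', 'model': 'gpt-4o', 'endpoint': None, 'apiKey': 'sk-placeholder'}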

View file

@@ -1,10 +1,8 @@
import asyncio
-import os
from typing import List, Type
from pydantic import BaseModel
import instructor
from tenacity import retry, stop_after_attempt
-from openai import AsyncOpenAI
import openai

from cognee.config import Config
@@ -19,23 +17,31 @@ config.load()
if config.monitoring_tool == MonitoringTool.LANGFUSE:
    from langfuse.openai import AsyncOpenAI, OpenAI
elif config.monitoring_tool == MonitoringTool.LANGSMITH:
-    from langsmith import wrap_openai
+    from langsmith import wrappers
    from openai import AsyncOpenAI
-    AsyncOpenAI = wrap_openai(AsyncOpenAI())
+    AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
else:
    from openai import AsyncOpenAI, OpenAI

class GenericAPIAdapter(LLMInterface):
    """Adapter for Generic API LLM provider API """
+    name: str
+    model: str
+    api_key: str

-    def __init__(self, api_endpoint, api_key: str, model: str):
+    def __init__(self, api_endpoint, api_key: str, model: str, name: str):
+        self.name = name
+        self.model = model
+        self.api_key = api_key

-        if infrastructure_config.get_config()["llm_provider"] == 'groq':
+        if infrastructure_config.get_config()["llm_provider"] == "groq":
            from groq import groq

-            self.aclient = instructor.from_openai(client = groq.Groq(
-                api_key=api_key,
-            ), mode=instructor.Mode.MD_JSON)
+            self.aclient = instructor.from_openai(
+                client = groq.Groq(
+                    api_key = api_key,
+                ),
+                mode = instructor.Mode.MD_JSON
+            )
        else:
            self.aclient = instructor.patch(
                AsyncOpenAI(
@@ -45,9 +51,6 @@ class GenericAPIAdapter(LLMInterface):
                mode = instructor.Mode.JSON,
            )

-        self.model = model

    @retry(stop = stop_after_attempt(5))
    def completions_with_backoff(self, **kwargs):
        """Wrapper around ChatCompletion.create w/ backoff"""

View file

@@ -1,9 +1,8 @@
"""Get the LLM client."""
from enum import Enum
-from cognee.config import Config
-from .anthropic.adapter import AnthropicAdapter
-from .openai.adapter import OpenAIAdapter
-from .generic_llm_api.adapter import GenericAPIAdapter
+import json
+import logging
+from cognee.infrastructure.llm import llm_config

# Define an Enum for LLM Providers
class LLMProvider(Enum):
@@ -12,20 +11,22 @@ class LLMProvider(Enum):
    ANTHROPIC = "anthropic"
    CUSTOM = "custom"

-config = Config()
-config.load()

def get_llm_client():
    """Get the LLM client based on the configuration using Enums."""
-    provider = LLMProvider(config.llm_provider)
+    logging.error(json.dumps(llm_config.to_dict()))
+
+    provider = LLMProvider(llm_config.llm_provider)

    if provider == LLMProvider.OPENAI:
-        return OpenAIAdapter(config.openai_key, config.openai_model)
+        from .openai.adapter import OpenAIAdapter
+        return OpenAIAdapter(llm_config.llm_api_key, llm_config.llm_model)
    elif provider == LLMProvider.OLLAMA:
-        return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
+        from .generic_llm_api.adapter import GenericAPIAdapter
+        return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
    elif provider == LLMProvider.ANTHROPIC:
-        return AnthropicAdapter(config.custom_model)
+        from .anthropic.adapter import AnthropicAdapter
+        return AnthropicAdapter(llm_config.llm_model)
    elif provider == LLMProvider.CUSTOM:
-        return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
+        from .generic_llm_api.adapter import GenericAPIAdapter
+        return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
    else:
        raise ValueError(f"Unsupported LLM provider: {provider}")

View file

@@ -23,12 +23,16 @@ else:
    from openai import AsyncOpenAI, OpenAI

class OpenAIAdapter(LLMInterface):
+    name = "OpenAI"
+    model: str
+    api_key: str
    """Adapter for OpenAI's GPT-3, GPT=4 API"""

    def __init__(self, api_key: str, model:str):
-        openai.api_key = api_key
-        self.aclient = instructor.from_openai(AsyncOpenAI())
-        self.client = instructor.from_openai(OpenAI())
+        self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
+        self.client = instructor.from_openai(OpenAI(api_key = api_key))
        self.model = model
+        self.api_key = api_key

    @retry(stop = stop_after_attempt(5))
    def completions_with_backoff(self, **kwargs):
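
Because the adapter now passes the key to the OpenAI clients explicitly instead of mutating the global openai.api_key, adapters with different keys should be able to coexist in one process. A small construction sketch (the key is a placeholder; no request is made at construction time):

    from cognee.infrastructure.llm.openai.adapter import OpenAIAdapter

    adapter = OpenAIAdapter(api_key = "sk-placeholder", model = "gpt-4o")
    print(adapter.name, adapter.model)   # "OpenAI" "gpt-4o"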

View file

@@ -36,7 +36,7 @@ def evaluate():
    evaluate_on_hotpotqa = Evaluate(devset = devset, num_threads = 1, display_progress = True, display_table = 5, max_tokens = 4096)

-    gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
+    gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)

    compiled_extract_knowledge_graph = ExtractKnowledgeGraph(lm = gpt4)
    compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))

@@ -58,7 +58,7 @@ def evaluate():
        return dsp.answer_match(example.answer, [answer_prediction.answer], frac = 0.8) or \
            dsp.passage_match([example.answer], [answer_prediction.answer])

-    gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
+    gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)
    dspy.settings.configure(lm = gpt4)

    evaluate_on_hotpotqa(compiled_extract_knowledge_graph, metric = evaluate_answer)

View file

@@ -7,7 +7,7 @@ config = Config()
config.load()

def run():
-    gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
+    gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)
    compiled_extract_knowledge_graph = ExtractKnowledgeGraph(lm = gpt4)
    compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))

View file

@@ -59,7 +59,7 @@ def train():
    trainset = [example.with_inputs("context", "question") for example in train_examples]

-    gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
+    gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)

    compiled_extract_knowledge_graph = optimizer.compile(ExtractKnowledgeGraph(lm = gpt4), trainset = trainset)

View file

@@ -41,7 +41,7 @@ def are_all_nodes_connected(graph: KnowledgeGraph) -> bool:
class ExtractKnowledgeGraph(dspy.Module):
-    def __init__(self, lm = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)):
+    def __init__(self, lm = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)):
        super().__init__()
        self.lm = lm
        dspy.settings.configure(lm=self.lm)

@@ -50,7 +50,7 @@ class ExtractKnowledgeGraph(dspy.Module):
    def forward(self, context: str, question: str):
        context = remove_stop_words(context)
-        context = trim_text_to_max_tokens(context, 1500, config.openai_model)
+        context = trim_text_to_max_tokens(context, 1500, config.llm_model)

        with dspy.context(lm = self.lm):
            graph = self.generate_graph(text = context).graph

@@ -79,7 +79,7 @@ def remove_stop_words(text):
#
# if __name__ == "__main__":
-#     gpt_4_turbo = dspy.OpenAI(model="gpt-4", max_tokens=4000, api_key=config.openai_key, model_type="chat")
+#     gpt_4_turbo = dspy.OpenAI(model="gpt-4", max_tokens=4000, api_key=config.llm_api_key, model_type="chat")
#     dspy.settings.configure(lm=gpt_4_turbo)

View file

@@ -4,7 +4,7 @@
from typing import Union, Dict
import networkx as nx
from cognee.shared.data_models import GraphDBType

-async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]:
+async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]:
    """
    Find the neighbours of a given node in the graph and return their descriptions.
    Supports both NetworkX graphs and Neo4j graph databases based on the configuration.

@@ -12,13 +12,12 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructur
    Parameters:
    - graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
    - query (str): Unused in this implementation but could be used for future enhancements.
-    - infrastructure_config (Dict): Configuration that includes the graph engine type.
    - other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node.

    Returns:
    - Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node.
    """
-    node_id = other_param.get('node_id') if other_param else None
+    node_id = other_param.get('node_id') if other_param else query

    if node_id is None:
        return {}

View file

@@ -7,7 +7,7 @@ from cognee.infrastructure.databases.graph.get_graph_client import get_graph_cli
import networkx as nx
from cognee.shared.data_models import GraphDBType

-async def search_neighbour(graph: Union[nx.Graph, any], node_id: str,
+async def search_neighbour(graph: Union[nx.Graph, any], query: str,
                           other_param: dict = None):
    """
    Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.

@@ -23,8 +23,7 @@ async def search_neighbour(graph: Union[nx.Graph, any], node_id: str,
    - List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node.
    """
    from cognee.infrastructure import infrastructure_config

-    if node_id is None:
-        node_id = other_param.get('node_id') if other_param else None
+    node_id = other_param.get('node_id') if other_param else query

    if node_id is None:
        return []

View file

@@ -1,41 +1,84 @@
from cognee.config import Config
from cognee.infrastructure import infrastructure_config
+from cognee.infrastructure.llm import llm_config

-config = Config()
-config.load()

def get_settings():
-    vector_engine_choice = infrastructure_config.get_config()["vector_engine_choice"]
+    config = Config()
+    config.load()

-    vector_db_options = [{
+    vector_dbs = [{
        "value": "weaviate",
        "label": "Weaviate",
    }, {
        "value": "qdrant",
        "label": "Qdrant",
    }, {
        "value": "lancedb",
        "label": "LanceDB",
    }]

-    vector_db_config = dict(
-        url = config.weaviate_url,
-        apiKey = config.weaviate_api_key,
-        choice = vector_db_options[0],
-        options = vector_db_options,
-    ) if vector_engine_choice == "weaviate" else dict(
-        url = config.qdrant_url,
-        apiKey = config.qdrant_api_key,
-        choice = vector_db_options[1],
-        options = vector_db_options,
-    ) if vector_engine_choice == "qdrant" else dict(
-        url = infrastructure_config.get_config("lance_db_path"),
-        choice = vector_db_options[2],
-        options = vector_db_options,
-    )
+    vector_engine = infrastructure_config.get_config("vector_engine")
+
+    llm_providers = [{
+        "value": "openai",
+        "label": "OpenAI",
+    }, {
+        "value": "ollama",
+        "label": "Ollama",
+    }, {
+        "value": "anthropic",
+        "label": "Anthropic",
+    }]

    return dict(
-        llm = dict(
-            openAIApiKey = config.openai_key[:-10] + "**********",
-        ),
-        vectorDB = vector_db_config,
+        llm = {
+            "provider": {
+                "label": llm_config.llm_provider,
+                "value": llm_config.llm_provider,
+            } if llm_config.llm_provider else llm_providers[0],
+            "model": {
+                "value": llm_config.llm_model,
+                "label": llm_config.llm_model,
+            } if llm_config.llm_model else None,
+            "apiKey": llm_config.llm_api_key[:-10] + "**********" if llm_config.llm_api_key else None,
+            "providers": llm_providers,
+            "models": {
+                "openai": [{
+                    "value": "gpt-4o",
+                    "label": "gpt-4o",
+                }, {
+                    "value": "gpt-4-turbo",
+                    "label": "gpt-4-turbo",
+                }, {
+                    "value": "gpt-3.5-turbo",
+                    "label": "gpt-3.5-turbo",
+                }],
+                "ollama": [{
+                    "value": "llama3",
+                    "label": "llama3",
+                }, {
+                    "value": "mistral",
+                    "label": "mistral",
+                }],
+                "anthropic": [{
+                    "value": "Claude 3 Opus",
+                    "label": "Claude 3 Opus",
+                }, {
+                    "value": "Claude 3 Sonnet",
+                    "label": "Claude 3 Sonnet",
+                }, {
+                    "value": "Claude 3 Haiku",
+                    "label": "Claude 3 Haiku",
+                }]
+            },
+        },
+        vectorDB = {
+            "provider": {
+                "label": vector_engine.name,
+                "value": vector_engine.name.lower(),
+            },
+            "url": vector_engine.url,
+            "apiKey": vector_engine.api_key,
+            "options": vector_dbs,
+        },
    )

View file

@@ -1,15 +1,20 @@
-import os
+import json
+import logging
from pydantic import BaseModel
-from cognee.config import Config
+from cognee.infrastructure.llm import llm_config
+from cognee.infrastructure import infrastructure_config

-config = Config()

class LLMConfig(BaseModel):
-    openAIApiKey: str
+    apiKey: str
+    model: str
+    provider: str

-async def save_llm_config(llm_config: LLMConfig):
-    if "*" in llm_config.openAIApiKey:
-        return
-    os.environ["OPENAI_API_KEY"] = llm_config.openAIApiKey
-    config.load()
+async def save_llm_config(new_llm_config: LLMConfig):
+    llm_config.llm_provider = new_llm_config.provider
+    llm_config.llm_model = new_llm_config.model
+
+    if "*****" not in new_llm_config.apiKey and len(new_llm_config.apiKey.strip()) > 0:
+        llm_config.llm_api_key = new_llm_config.apiKey
+
+    logging.error(json.dumps(llm_config.to_dict()))
+
+    infrastructure_config.llm_engine = None

View file

@@ -7,24 +7,24 @@ from cognee.infrastructure import infrastructure_config
config = Config()

class VectorDBConfig(BaseModel):
-    choice: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
    url: str
    apiKey: str
+    provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]

async def save_vector_db_config(vector_db_config: VectorDBConfig):
-    if vector_db_config.choice == "weaviate":
+    if vector_db_config.provider == "weaviate":
        os.environ["WEAVIATE_URL"] = vector_db_config.url
        os.environ["WEAVIATE_API_KEY"] = vector_db_config.apiKey

        remove_qdrant_config()

-    if vector_db_config.choice == "qdrant":
+    if vector_db_config.provider == "qdrant":
        os.environ["QDRANT_URL"] = vector_db_config.url
        os.environ["QDRANT_API_KEY"] = vector_db_config.apiKey

        remove_weaviate_config()

-    if vector_db_config.choice == "lancedb":
+    if vector_db_config.provider == "lancedb":
        remove_qdrant_config()
        remove_weaviate_config()

View file

@@ -43,6 +43,7 @@ services:
      limits:
        cpus: "4.0"
        memory: 8GB
+
  frontend:
    container_name: frontend
    build:
@@ -55,7 +56,6 @@ services:
    networks:
      - cognee_backend
-
  postgres:
    image: postgres
    container_name: postgres
@@ -69,18 +69,7 @@ services:
      - cognee_backend
    ports:
      - "5432:5432"
-  litellm:
-    build:
-      context: .
-      args:
-        target: runtime
-    image: ghcr.io/berriai/litellm:main-latest
-    ports:
-      - "4000:4000" # Map the container port to the host, change the host port if necessary
-    volumes:
-      - ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
-    # You can change the port or number of workers as per your requirements or pass any new supported CLI augument. Make sure the port passed here matches with the container port defined above in `ports` value
-    command: [ "--config", "/app/config.yaml", "--port", "4000", "--num_workers", "8" ]
  falkordb:
    image: falkordb/falkordb:edge
    container_name: falkordb

poetry.lock (generated, new file, 8521 lines)

File diff suppressed because it is too large.

View file

@@ -60,7 +60,6 @@ tiktoken = "^0.6.0"
dspy-ai = "2.4.3"
posthog = "^3.5.0"
lancedb = "^0.6.10"
importlib-metadata = "6.8.0"
litellm = "^1.37.3"
groq = "^0.5.0"