feat: add llm config

Boris Arzentar 2024-05-22 22:36:30 +02:00
parent 9bb30bc43a
commit 84c0c8cab5
30 changed files with 9034 additions and 214 deletions

1
.gitignore vendored
View file

@ -169,3 +169,4 @@ cognee/cache/
# Default cognee system directory, used in development
.cognee_system/
.data_storage/

View file

@ -228,6 +228,126 @@
"node": ">= 10"
}
},
"node_modules/@next/swc-darwin-x64": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.3.tgz",
"integrity": "sha512-6adp7waE6P1TYFSXpY366xwsOnEXM+y1kgRpjSRVI2CBDOcbRjsJ67Z6EgKIqWIue52d2q/Mx8g9MszARj8IEA==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-linux-arm64-gnu": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.3.tgz",
"integrity": "sha512-cuzCE/1G0ZSnTAHJPUT1rPgQx1w5tzSX7POXSLaS7w2nIUJUD+e25QoXD/hMfxbsT9rslEXugWypJMILBj/QsA==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-linux-arm64-musl": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.3.tgz",
"integrity": "sha512-0D4/oMM2Y9Ta3nGuCcQN8jjJjmDPYpHX9OJzqk42NZGJocU2MqhBq5tWkJrUQOQY9N+In9xOdymzapM09GeiZw==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-linux-x64-gnu": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.3.tgz",
"integrity": "sha512-ENPiNnBNDInBLyUU5ii8PMQh+4XLr4pG51tOp6aJ9xqFQ2iRI6IH0Ds2yJkAzNV1CfyagcyzPfROMViS2wOZ9w==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-linux-x64-musl": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.3.tgz",
"integrity": "sha512-BTAbq0LnCbF5MtoM7I/9UeUu/8ZBY0i8SFjUMCbPDOLv+un67e2JgyN4pmgfXBwy/I+RHu8q+k+MCkDN6P9ViQ==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-win32-arm64-msvc": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.3.tgz",
"integrity": "sha512-AEHIw/dhAMLNFJFJIJIyOFDzrzI5bAjI9J26gbO5xhAKHYTZ9Or04BesFPXiAYXDNdrwTP2dQceYA4dL1geu8A==",
"cpu": [
"arm64"
],
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-win32-ia32-msvc": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.3.tgz",
"integrity": "sha512-vga40n1q6aYb0CLrM+eEmisfKCR45ixQYXuBXxOOmmoV8sYST9k7E3US32FsY+CkkF7NtzdcebiFT4CHuMSyZw==",
"cpu": [
"ia32"
],
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@next/swc-win32-x64-msvc": {
"version": "14.2.3",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.3.tgz",
"integrity": "sha512-Q1/zm43RWynxrO7lW4ehciQVj+5ePBhOK+/K2P7pLFX3JaJ/IZVC69SHidrmZSOkqz7ECIOhhy7XhAFG4JYyHA==",
"cpu": [
"x64"
],
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">= 10"
}
},
"node_modules/@nodelib/fs.scandir": {
"version": "2.1.5",
"dev": true,

View file

@ -73,6 +73,8 @@ function useDatasets() {
if (datasets.length > 0) {
checkDatasetStatuses(datasets);
} else {
window.location.href = '/wizard';
}
});
}, [checkDatasetStatuses]);

View file

@ -15,18 +15,24 @@
flex: 1;
padding: 16px;
border-top: 2px solid white;
overflow: hidden;
}
.messagesContainer {
flex: 1;
overflow-y: auto;
}
.messages {
flex: 1;
padding-top: 24px;
padding-bottom: 24px;
overflow-y: auto;
}
.message {
padding: 16px;
border-radius: var(--border-radius);
width: max-content;
}
.userMessage {

View file

@ -1,4 +1,4 @@
import { CTAButton, CloseIcon, GhostButton, Input, Spacer, Stack, Text } from 'ohmy-ui';
import { CTAButton, CloseIcon, GhostButton, Input, Spacer, Stack, Text, DropdownSelect } from 'ohmy-ui';
import styles from './SearchView.module.css';
import { useCallback, useState } from 'react';
import { v4 } from 'uuid';
@ -22,10 +22,27 @@ export default function SearchView({ onClose }: SearchViewProps) {
setInputValue(event.target.value);
}, []);
const searchOptions = [{
value: 'SIMILARITY',
label: 'Similarity',
}, {
value: 'NEIGHBOR',
label: 'Neighbor',
}, {
value: 'SUMMARY',
label: 'Summary',
}, {
value: 'ADJACENT',
label: 'Adjacent',
}, {
value: 'CATEGORIES',
label: 'Categories',
}];
const [searchType, setSearchType] = useState(searchOptions[0]);
const handleSearchSubmit = useCallback((event: React.FormEvent<HTMLFormElement>) => {
event.preventDefault();
setMessages((currentMessages) => [
...currentMessages,
{
@ -43,6 +60,7 @@ export default function SearchView({ onClose }: SearchViewProps) {
body: JSON.stringify({
query_params: {
query: inputValue,
searchType: searchType.value,
},
}),
})
@ -58,8 +76,8 @@ export default function SearchView({ onClose }: SearchViewProps) {
]);
setInputValue('');
})
}, [inputValue]);
}, [inputValue, searchType]);
return (
<Stack className={styles.searchViewContainer}>
<Stack gap="between" align="center/" orientation="horizontal">
@ -71,20 +89,27 @@ export default function SearchView({ onClose }: SearchViewProps) {
</GhostButton>
</Stack>
<Stack className={styles.searchContainer}>
<Stack gap="2" className={styles.messages} align="end">
{messages.map((message) => (
<Text
key={message.id}
className={classNames(styles.message, {
[styles.userMessage]: message.user === "user",
})}
>
{message.text}
</Text>
))}
</Stack>
<div className={styles.messagesContainer}>
<Stack gap="2" className={styles.messages} align="end">
{messages.map((message) => (
<Text
key={message.id}
className={classNames(styles.message, {
[styles.userMessage]: message.user === "user",
})}
>
{message.text}
</Text>
))}
</Stack>
</div>
<form onSubmit={handleSearchSubmit}>
<Stack orientation="horizontal" gap="2">
<DropdownSelect
value={searchType}
options={searchOptions}
onChange={setSearchType}
/>
<Input value={inputValue} onChange={handleInputChange} name="searchInput" placeholder="Search" />
<CTAButton type="submit">Search</CTAButton>
</Stack>
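
As a quick orientation for the new search flow, here is a minimal sketch (in Python, using requests) of the request body SearchView now builds: the backend reads query_params["query"] and query_params["searchType"]. The /search path and the 0.0.0.0:8000 host are assumptions carried over from the settings fetch elsewhere in this commit, since the fetch URL itself falls outside this hunk.

import requests

# Illustrative call mirroring the SearchView submit handler; values are placeholders.
response = requests.post(
    "http://0.0.0.0:8000/search",  # assumed host/path, not shown in this hunk
    json={
        "query_params": {
            "query": "example question",
            # One of the dropdown options: SIMILARITY, NEIGHBOR, SUMMARY, ADJACENT, CATEGORIES
            "searchType": "SIMILARITY",
        },
    },
)
print(response.json())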

View file

@ -7,12 +7,22 @@ interface SelectOption {
}
export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
const [llmConfig, setLLMConfig] = useState<{ openAIApiKey: string }>();
const [llmConfig, setLLMConfig] = useState<{
apiKey: string;
model: SelectOption;
models: {
openai: SelectOption[];
ollama: SelectOption[];
anthropic: SelectOption[];
};
provider: SelectOption;
providers: SelectOption[];
}>();
const [vectorDBConfig, setVectorDBConfig] = useState<{
choice: SelectOption;
options: SelectOption[];
url: string;
apiKey: string;
provider: SelectOption;
options: SelectOption[];
}>();
const {
@ -23,10 +33,18 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
const saveConfig = (event: React.FormEvent<HTMLFormElement>) => {
event.preventDefault();
const newOpenAIApiKey = event.target.openAIApiKey.value;
const newVectorDBChoice = vectorDBConfig?.choice.value;
const newVectorDBUrl = event.target.vectorDBUrl.value;
const newVectorDBApiKey = event.target.vectorDBApiKey.value;
const newVectorConfig = {
provider: vectorDBConfig?.provider.value,
url: event.target.vectorDBUrl.value,
apiKey: event.target.vectorDBApiKey.value,
};
const newLLMConfig = {
provider: llmConfig?.provider.value,
model: llmConfig?.model.value,
apiKey: event.target.llmApiKey.value,
};
startSaving();
@ -36,14 +54,8 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
'Content-Type': 'application/json',
},
body: JSON.stringify({
llm: {
openAIApiKey: newOpenAIApiKey,
},
vectorDB: {
choice: newVectorDBChoice,
url: newVectorDBUrl,
apiKey: newVectorDBApiKey,
},
llm: newLLMConfig,
vectorDB: newVectorConfig,
}),
})
.then(() => {
@ -52,12 +64,12 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
.finally(() => stopSaving());
};
const handleVectorDBChange = useCallback((newChoice: SelectOption) => {
const handleVectorDBChange = useCallback((newVectorDBProvider: SelectOption) => {
setVectorDBConfig((config) => {
if (config?.choice !== newChoice) {
if (config?.provider !== newVectorDBProvider) {
return {
...config,
choice: newChoice,
provider: newVectorDBProvider,
url: '',
apiKey: '',
};
@ -66,11 +78,40 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
});
}, []);
const handleLLMProviderChange = useCallback((newLLMProvider: SelectOption) => {
setLLMConfig((config) => {
if (config?.provider !== newLLMProvider) {
return {
...config,
provider: newLLMProvider,
model: config?.models[newLLMProvider.value][0],
apiKey: '',
};
}
return config;
});
}, []);
const handleLLMModelChange = useCallback((newLLMModel: SelectOption) => {
setLLMConfig((config) => {
if (config?.model !== newLLMModel) {
return {
...config,
model: newLLMModel,
};
}
return config;
});
}, []);
useEffect(() => {
const fetchVectorDBChoices = async () => {
const response = await fetch('http://0.0.0.0:8000/settings');
const settings = await response.json();
if (!settings.llm.model) {
settings.llm.model = settings.llm.models[settings.llm.provider.value][0];
}
setLLMConfig(settings.llm);
setVectorDBConfig(settings.vectorDB);
};
@ -79,40 +120,54 @@ export default function SettingsModal({ isOpen = false, onClose = () => {} }) {
return (
<Modal isOpen={isOpen} onClose={onClose}>
<Stack gap="4" orientation="vertical" align="center/">
<Stack gap="8" orientation="vertical" align="center/">
<H2>Settings</H2>
<form onSubmit={saveConfig} style={{ width: '100%' }}>
<Stack gap="2" orientation="vertical">
<H3>LLM Config</H3>
<FormGroup orientation="vertical" align="center/" gap="1">
<FormLabel>OpenAI API Key</FormLabel>
<Stack gap="4" orientation="vertical">
<Stack gap="2" orientation="vertical">
<H3>LLM Config</H3>
<FormGroup orientation="horizontal" align="center/" gap="4">
<FormLabel>LLM provider:</FormLabel>
<DropdownSelect
value={llmConfig?.provider}
options={llmConfig?.providers}
onChange={handleLLMProviderChange}
/>
</FormGroup>
<FormGroup orientation="horizontal" align="center/" gap="4">
<FormLabel>LLM model:</FormLabel>
<DropdownSelect
value={llmConfig?.model}
options={llmConfig?.provider ? llmConfig?.models[llmConfig?.provider.value] : []}
onChange={handleLLMModelChange}
/>
</FormGroup>
<FormInput>
<Input defaultValue={llmConfig?.openAIApiKey} name="openAIApiKey" placeholder="OpenAI API Key" />
<Input defaultValue={llmConfig?.apiKey} name="llmApiKey" placeholder="LLM API key" />
</FormInput>
</FormGroup>
</Stack>
<H3>Vector Database Config</H3>
<DropdownSelect
value={vectorDBConfig?.choice}
options={vectorDBConfig?.options}
onChange={handleVectorDBChange}
/>
<FormGroup orientation="vertical" align="center/" gap="1">
<FormLabel>Vector DB url</FormLabel>
<Stack gap="2" orientation="vertical">
<H3>Vector Database Config</H3>
<FormGroup orientation="horizontal" align="center/" gap="4">
<FormLabel>Vector DB provider:</FormLabel>
<DropdownSelect
value={vectorDBConfig?.provider}
options={vectorDBConfig?.options}
onChange={handleVectorDBChange}
/>
</FormGroup>
<FormInput>
<Input defaultValue={vectorDBConfig?.url} name="vectorDBUrl" placeholder="Vector DB API url" />
<Input defaultValue={vectorDBConfig?.url} name="vectorDBUrl" placeholder="Vector DB instance url" />
</FormInput>
</FormGroup>
<FormGroup orientation="vertical" align="center/" gap="1">
<FormLabel>Vector DB API key</FormLabel>
<FormInput>
<Input defaultValue={vectorDBConfig?.apiKey} name="vectorDBApiKey" placeholder="Vector DB API key" />
</FormInput>
</FormGroup>
<Stack align="/end">
<Spacer top="2">
<CTAButton type="submit">Save</CTAButton>
</Spacer>
<Stack align="/end">
<Spacer top="2">
<CTAButton type="submit">Save</CTAButton>
</Spacer>
</Stack>
</Stack>
</Stack>
</form>
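
For reference, a minimal sketch of the payload the reworked saveConfig handler posts to the settings endpoint, expressed with Python's requests; the field names mirror newLLMConfig and newVectorConfig above, while the concrete provider, model, url and key values are placeholders.

import requests

settings_payload = {
    "llm": {
        "provider": "openai",
        "model": "gpt-4o",
        "apiKey": "sk-...",  # placeholder key
    },
    "vectorDB": {
        "provider": "qdrant",
        "url": "https://example-qdrant-host:6333",  # hypothetical instance url
        "apiKey": "qdrant-key",  # placeholder key
    },
}

# POST /settings is handled by save_config in api.py, shown later in this commit.
requests.post("http://0.0.0.0:8000/settings", json=settings_payload)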

View file

@ -3,6 +3,7 @@ import os
import aiohttp
import uvicorn
import json
import logging
# Set up logging
@ -16,7 +17,7 @@ from cognee.config import Config
config = Config()
config.load()
from typing import Dict, Any, List, Union, Annotated, Literal
from typing import Dict, Any, List, Union, Annotated, Literal, Optional
from fastapi import FastAPI, HTTPException, Form, File, UploadFile, Query
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
@ -25,6 +26,7 @@ from pydantic import BaseModel
app = FastAPI(debug=True)
origins = [
"http://frontend:3000",
"http://localhost:3000",
"http://localhost:3001",
]
@ -220,8 +222,16 @@ async def search(payload: SearchPayload):
from cognee import search as cognee_search
try:
search_type = "SIMILARITY"
await cognee_search(search_type, payload.query_params)
search_type = payload.query_params["searchType"]
params = {
"query": payload.query_params["query"],
}
results = await cognee_search(search_type, params)
return JSONResponse(
status_code = 200,
content = json.dumps(results)
)
except Exception as error:
return JSONResponse(
status_code = 409,
@ -236,25 +246,26 @@ async def get_settings():
class LLMConfig(BaseModel):
openAIApiKey: str
provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
model: str
apiKey: str
class VectorDBConfig(BaseModel):
choice: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
url: str
apiKey: str
class SettingsPayload(BaseModel):
llm: LLMConfig | None = None
vectorDB: VectorDBConfig | None = None
llm: Optional[LLMConfig] = None
vectorDB: Optional[VectorDBConfig] = None
@app.post("/settings", response_model=dict)
async def save_config(new_settings: SettingsPayload):
from cognee.modules.settings import save_llm_config, save_vector_db_config
if hasattr(new_settings, "llm"):
if new_settings.llm is not None:
await save_llm_config(new_settings.llm)
if hasattr(new_settings, "vectorDB"):
if new_settings.vectorDB is not None:
await save_vector_db_config(new_settings.vectorDB)
return JSONResponse(
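
A small, self-contained sketch of why the handler switched from hasattr checks to explicit None checks: hasattr(new_settings, "llm") is always true on a Pydantic model, so the old guard never skipped a missing section, whereas the Optional fields below leave an absent section as None and the corresponding save call is skipped. The model definitions are taken from this hunk; the sample data is illustrative.

from typing import Literal, Optional, Union
from pydantic import BaseModel

class LLMConfig(BaseModel):
    provider: Union[Literal["openai"], Literal["ollama"], Literal["anthropic"]]
    model: str
    apiKey: str

class VectorDBConfig(BaseModel):
    provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
    url: str
    apiKey: str

class SettingsPayload(BaseModel):
    llm: Optional[LLMConfig] = None
    vectorDB: Optional[VectorDBConfig] = None

# Only the LLM section is supplied; vectorDB stays None and save_vector_db_config is skipped.
payload = SettingsPayload(llm={"provider": "openai", "model": "gpt-4o", "apiKey": "sk-..."})
assert payload.vectorDB is None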

View file

@ -10,11 +10,13 @@ from dotenv import load_dotenv
from cognee.root_dir import get_absolute_path
from cognee.shared.data_models import ChunkStrategy, DefaultGraphModel
base_dir = Path(__file__).resolve().parent.parent
# Load the .env file from the base directory
dotenv_path = base_dir / ".env"
load_dotenv(dotenv_path=dotenv_path)
def load_dontenv():
base_dir = Path(__file__).resolve().parent.parent
# Load the .env file from the base directory
dotenv_path = base_dir / ".env"
load_dotenv(dotenv_path=dotenv_path, override = True)
load_dontenv()
@dataclass
class Config:
@ -50,16 +52,20 @@ class Config:
graph_filename = os.getenv("GRAPH_NAME", "cognee_graph.pkl")
# Model parameters
llm_provider: str = os.getenv("LLM_PROVIDER","openai") #openai, or custom or ollama
custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1"
ollama_key: Optional[str] = "ollama"
ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct"
openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o"
model_endpoint: str = "openai"
openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
llm_provider: str = os.getenv("LLM_PROVIDER", "openai") #openai, or custom or ollama
llm_model: str = os.getenv("LLM_MODEL", None)
llm_api_key: str = os.getenv("LLM_API_KEY", None)
llm_endpoint: str = os.getenv("LLM_ENDPOINT", None)
# custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
# custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
# custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
# ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1"
# ollama_key: Optional[str] = "ollama"
# ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct"
# openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o"
# model_endpoint: str = "openai"
# llm_api_key: Optional[str] = os.getenv("OPENAI_API_KEY")
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
openai_embedding_model = "text-embedding-3-large"
openai_embedding_dimensions = 3072
@ -132,6 +138,7 @@ class Config:
def load(self):
"""Loads the configuration from a file or environment variables."""
load_dontenv()
config = configparser.ConfigParser()
config.read(self.config_path)
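
A minimal sketch of driving the new consolidated LLM settings through environment variables; the variable names are the LLM_* ones introduced above, the values are placeholders, and it assumes the cognee package is importable.

import os

# Set before importing Config, since the dataclass defaults read os.getenv at import time.
os.environ["LLM_PROVIDER"] = "openai"   # openai, ollama, anthropic or custom
os.environ["LLM_MODEL"] = "gpt-4o"
os.environ["LLM_API_KEY"] = "sk-..."    # placeholder key
# LLM_ENDPOINT is only relevant for the ollama/custom providers.

from cognee.config import Config

config = Config()
config.load()  # re-reads the config file and the .env file via load_dontenv()
print(config.llm_provider, config.llm_model)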

View file

@ -3,7 +3,7 @@ from .databases.relational import DuckDBAdapter, DatabaseEngine
from .databases.vector.vector_db_interface import VectorDBInterface
from .databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbeddingEngine
from .llm.llm_interface import LLMInterface
from .llm.openai.adapter import OpenAIAdapter
from .llm.get_llm_client import get_llm_client
from .files.storage import LocalStorage
from .data.chunking.DefaultChunkEngine import DefaultChunkEngine
from ..shared.data_models import GraphDBType, DefaultContentPrediction, KnowledgeGraph, SummarizedContent, \
@ -35,6 +35,10 @@ class InfrastructureConfig():
chunk_engine = None
graph_topology = config.graph_topology
monitoring_tool = config.monitoring_tool
llm_provider: str = None
llm_model: str = None
llm_endpoint: str = None
llm_api_key: str = None
def get_config(self, config_entity: str = None) -> dict:
if (config_entity is None or config_entity == "database_engine") and self.database_engine is None:
@ -84,7 +88,8 @@ class InfrastructureConfig():
self.graph_topology = config.graph_topology
if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)
self.llm_engine = get_llm_client()
if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
self.database_directory_path = self.system_root_directory + "/" + config.db_path
@ -115,8 +120,8 @@ class InfrastructureConfig():
from .databases.vector.qdrant.QDrantAdapter import QDrantAdapter
self.vector_engine = QDrantAdapter(
qdrant_url = config.qdrant_url,
qdrant_api_key = config.qdrant_api_key,
url = config.qdrant_url,
api_key = config.qdrant_api_key,
embedding_engine = self.embedding_engine
)
self.vector_engine_choice = "qdrant"
@ -127,11 +132,10 @@ class InfrastructureConfig():
LocalStorage.ensure_directory_exists(lance_db_path)
self.vector_engine = LanceDBAdapter(
uri = lance_db_path,
url = lance_db_path,
api_key = None,
embedding_engine = self.embedding_engine,
)
self.lance_db_path = lance_db_path
self.vector_engine_choice = "lancedb"
if config_entity is not None:
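
A brief sketch of how the lazily built LLM engine is now fetched from the shared infrastructure config; this assumes that get_config called with an entity name returns that entity directly, as the settings module does for "vector_engine" later in this commit.

from cognee.infrastructure import infrastructure_config

# First access constructs the client via get_llm_client() from the current llm_config;
# which adapter comes back depends on the configured provider.
llm_engine = infrastructure_config.get_config("llm_engine")

# Resetting it (as save_llm_config does) forces the next access to rebuild the client.
infrastructure_config.llm_engine = None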

View file

@ -1,6 +1,5 @@
from typing import List, Optional, get_type_hints, Generic, TypeVar
import asyncio
from pydantic import BaseModel, Field
import lancedb
from lancedb.pydantic import Vector, LanceModel
from cognee.infrastructure.files.storage import LocalStorage
@ -9,21 +8,25 @@ from ..vector_db_interface import VectorDBInterface, DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine
class LanceDBAdapter(VectorDBInterface):
name = "LanceDB"
url: str
api_key: str
connection: lancedb.AsyncConnection = None
def __init__(
self,
uri: Optional[str],
url: Optional[str],
api_key: Optional[str],
embedding_engine: EmbeddingEngine,
):
self.uri = uri
self.url = url
self.api_key = api_key
self.embedding_engine = embedding_engine
async def get_connection(self):
if self.connection is None:
self.connection = await lancedb.connect_async(self.uri, api_key = self.api_key)
self.connection = await lancedb.connect_async(self.url, api_key = self.api_key)
return self.connection
@ -35,12 +38,12 @@ class LanceDBAdapter(VectorDBInterface):
collection_names = await connection.table_names()
return collection_name in collection_names
async def create_collection(self, collection_name: str, payload_schema: BaseModel):
async def create_collection(self, collection_name: str, payload_schema = None):
data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size()
class LanceDataPoint(LanceModel):
id: data_point_types["id"] = Field(...)
id: data_point_types["id"]
vector: Vector(vector_size)
payload: payload_schema
@ -128,7 +131,7 @@ class LanceDBAdapter(VectorDBInterface):
collection_name: str,
query_texts: List[str],
limit: int = None,
with_vector: bool = False,
with_vectors: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
@ -137,11 +140,11 @@ class LanceDBAdapter(VectorDBInterface):
collection_name = collection_name,
query_vector = query_vector,
limit = limit,
with_vector = with_vector,
with_vector = with_vectors,
) for query_vector in query_vectors]
)
async def prune(self):
# Clean up the database if it was set up as temporary
if self.uri.startswith("/"):
LocalStorage.remove_all(self.uri) # Remove the temporary directory and files inside
if self.url.startswith("/"):
LocalStorage.remove_all(self.url) # Remove the temporary directory and files inside
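
A hedged construction sketch for the renamed LanceDB adapter parameters (uri becomes url, plus an explicit api_key). The import paths are assumptions modeled on the QDrantAdapter import used in the infrastructure config, and DefaultEmbeddingEngine is assumed to take no constructor arguments.

# Assumed import paths, mirroring the relative imports seen in cognee.infrastructure.
from cognee.infrastructure.databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter
from cognee.infrastructure.databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbeddingEngine

vector_engine = LanceDBAdapter(
    url = "/tmp/cognee/lancedb",  # a local path; prune() removes it because it starts with "/"
    api_key = None,               # LanceDB runs locally here, so no key is required
    embedding_engine = DefaultEmbeddingEngine(),
)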

View file

@ -26,29 +26,29 @@ def create_quantization_config(quantization_config: Dict):
return None
class QDrantAdapter(VectorDBInterface):
qdrant_url: str = None
name = "Qdrant"
url: str = None
api_key: str = None
qdrant_path: str = None
qdrant_api_key: str = None
def __init__(self, qdrant_url, qdrant_api_key, embedding_engine: EmbeddingEngine, qdrant_path = None):
def __init__(self, url, api_key, embedding_engine: EmbeddingEngine, qdrant_path = None):
self.embedding_engine = embedding_engine
if qdrant_path is not None:
self.qdrant_path = qdrant_path
else:
self.qdrant_url = qdrant_url
self.qdrant_api_key = qdrant_api_key
self.url = url
self.api_key = api_key
def get_qdrant_client(self) -> AsyncQdrantClient:
if self.qdrant_path is not None:
return AsyncQdrantClient(
path = self.qdrant_path, port=6333
)
elif self.qdrant_url is not None:
elif self.url is not None:
return AsyncQdrantClient(
url = self.qdrant_url,
api_key = self.qdrant_api_key,
url = self.url,
api_key = self.api_key,
port = 6333
)

View file

@ -1,7 +1,6 @@
import asyncio
from uuid import UUID
from typing import List, Optional
from multiprocessing import Pool
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult
@ -9,19 +8,24 @@ from ..embeddings.EmbeddingEngine import EmbeddingEngine
class WeaviateAdapter(VectorDBInterface):
async_pool: Pool = None
name = "Weaviate"
url: str
api_key: str
embedding_engine: EmbeddingEngine = None
def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine):
import weaviate
import weaviate.classes as wvc
self.url = url
self.api_key = api_key
self.embedding_engine = embedding_engine
self.client = weaviate.connect_to_wcs(
cluster_url=url,
auth_credentials=weaviate.auth.AuthApiKey(api_key),
additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30))
cluster_url = url,
auth_credentials = weaviate.auth.AuthApiKey(api_key),
additional_config = wvc.init.AdditionalConfig(timeout = wvc.init.Timeout(init=30))
)
async def embed_data(self, data: List[str]) -> List[float]:

View file

@ -0,0 +1 @@
from .config import llm_config


View file

@ -8,7 +8,9 @@ from cognee.infrastructure.llm.prompts import read_query_prompt
class AnthropicAdapter(LLMInterface):
"""Adapter for Ollama's API"""
"""Adapter for Anthropic API"""
name = "Anthropic"
model: str
def __init__(self, model: str = None):
self.aclient = instructor.patch(

View file

@ -0,0 +1,16 @@
class LLMConfig():
llm_provider: str = None
llm_model: str = None
llm_endpoint: str = None
llm_api_key: str = None
def to_dict(self) -> dict:
return {
"provider": self.llm_provider,
"model": self.llm_model,
"endpoint": self.llm_endpoint,
"apiKey": self.llm_api_key,
}
llm_config = LLMConfig()
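
A quick sketch of how this module-level singleton is meant to be used: callers mutate its attributes in place (as save_llm_config does) and to_dict() produces the camelCase shape the settings endpoint exposes. Values are placeholders.

from cognee.infrastructure.llm import llm_config

llm_config.llm_provider = "openai"
llm_config.llm_model = "gpt-4o"
llm_config.llm_api_key = "sk-..."  # placeholder key

print(llm_config.to_dict())
# {'provider': 'openai', 'model': 'gpt-4o', 'endpoint': None, 'apiKey': 'sk-...'}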

View file

@ -1,10 +1,8 @@
import asyncio
import os
from typing import List, Type
from pydantic import BaseModel
import instructor
from tenacity import retry, stop_after_attempt
from openai import AsyncOpenAI
import openai
from cognee.config import Config
@ -19,23 +17,31 @@ config.load()
if config.monitoring_tool == MonitoringTool.LANGFUSE:
from langfuse.openai import AsyncOpenAI, OpenAI
elif config.monitoring_tool == MonitoringTool.LANGSMITH:
from langsmith import wrap_openai
from langsmith import wrappers
from openai import AsyncOpenAI
AsyncOpenAI = wrap_openai(AsyncOpenAI())
AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
else:
from openai import AsyncOpenAI, OpenAI
class GenericAPIAdapter(LLMInterface):
"""Adapter for Generic API LLM provider API """
name: str
model: str
api_key: str
def __init__(self, api_endpoint, api_key: str, model: str):
def __init__(self, api_endpoint, api_key: str, model: str, name: str):
self.name = name
self.model = model
self.api_key = api_key
if infrastructure_config.get_config()["llm_provider"] == 'groq':
if infrastructure_config.get_config()["llm_provider"] == "groq":
from groq import groq
self.aclient = instructor.from_openai(client = groq.Groq(
api_key=api_key,
), mode=instructor.Mode.MD_JSON)
self.aclient = instructor.from_openai(
client = groq.Groq(
api_key = api_key,
),
mode = instructor.Mode.MD_JSON
)
else:
self.aclient = instructor.patch(
AsyncOpenAI(
@ -45,9 +51,6 @@ class GenericAPIAdapter(LLMInterface):
mode = instructor.Mode.JSON,
)
self.model = model
@retry(stop = stop_after_attempt(5))
def completions_with_backoff(self, **kwargs):
"""Wrapper around ChatCompletion.create w/ backoff"""

View file

@ -1,9 +1,8 @@
"""Get the LLM client."""
from enum import Enum
from cognee.config import Config
from .anthropic.adapter import AnthropicAdapter
from .openai.adapter import OpenAIAdapter
from .generic_llm_api.adapter import GenericAPIAdapter
import json
import logging
from cognee.infrastructure.llm import llm_config
# Define an Enum for LLM Providers
class LLMProvider(Enum):
@ -12,20 +11,22 @@ class LLMProvider(Enum):
ANTHROPIC = "anthropic"
CUSTOM = "custom"
config = Config()
config.load()
def get_llm_client():
"""Get the LLM client based on the configuration using Enums."""
provider = LLMProvider(config.llm_provider)
logging.error(json.dumps(llm_config.to_dict()))
provider = LLMProvider(llm_config.llm_provider)
if provider == LLMProvider.OPENAI:
return OpenAIAdapter(config.openai_key, config.openai_model)
from .openai.adapter import OpenAIAdapter
return OpenAIAdapter(llm_config.llm_api_key, llm_config.llm_model)
elif provider == LLMProvider.OLLAMA:
return GenericAPIAdapter(config.ollama_endpoint, config.ollama_key, config.ollama_model)
from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")
elif provider == LLMProvider.ANTHROPIC:
return AnthropicAdapter(config.custom_model)
from .anthropic.adapter import AnthropicAdapter
return AnthropicAdapter(llm_config.llm_model)
elif provider == LLMProvider.CUSTOM:
return GenericAPIAdapter(config.custom_endpoint, config.custom_key, config.custom_model)
from .generic_llm_api.adapter import GenericAPIAdapter
return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
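
A minimal sketch of resolving a client through the rewritten factory; the adapters are imported lazily inside each branch, so only the chosen provider's dependencies are needed. The import path for get_llm_client follows the infrastructure package __init__ shown earlier; model and key values are placeholders.

from cognee.infrastructure.llm import llm_config
from cognee.infrastructure.llm.get_llm_client import get_llm_client

llm_config.llm_provider = "anthropic"
llm_config.llm_model = "Claude 3 Opus"
llm_config.llm_api_key = "placeholder-key"

# Dispatches on LLMProvider(llm_config.llm_provider); "ollama" and "custom"
# both go through GenericAPIAdapter with a provider-specific display name.
client = get_llm_client()
print(type(client).__name__)  # AnthropicAdapter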

View file

@ -23,12 +23,16 @@ else:
from openai import AsyncOpenAI, OpenAI
class OpenAIAdapter(LLMInterface):
name = "OpenAI"
model: str
api_key: str
"""Adapter for OpenAI's GPT-3, GPT=4 API"""
def __init__(self, api_key: str, model:str):
openai.api_key = api_key
self.aclient = instructor.from_openai(AsyncOpenAI())
self.client = instructor.from_openai(OpenAI())
self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
self.client = instructor.from_openai(OpenAI(api_key = api_key))
self.model = model
self.api_key = api_key
@retry(stop = stop_after_attempt(5))
def completions_with_backoff(self, **kwargs):

View file

@ -36,7 +36,7 @@ def evaluate():
evaluate_on_hotpotqa = Evaluate(devset = devset, num_threads = 1, display_progress = True, display_table = 5, max_tokens = 4096)
gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)
compiled_extract_knowledge_graph = ExtractKnowledgeGraph(lm = gpt4)
compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))
@ -58,7 +58,7 @@ def evaluate():
return dsp.answer_match(example.answer, [answer_prediction.answer], frac = 0.8) or \
dsp.passage_match([example.answer], [answer_prediction.answer])
gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)
dspy.settings.configure(lm = gpt4)
evaluate_on_hotpotqa(compiled_extract_knowledge_graph, metric = evaluate_answer)

View file

@ -7,7 +7,7 @@ config = Config()
config.load()
def run():
gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)
compiled_extract_knowledge_graph = ExtractKnowledgeGraph(lm = gpt4)
compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))

View file

@ -59,7 +59,7 @@ def train():
trainset = [example.with_inputs("context", "question") for example in train_examples]
gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
gpt4 = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)
compiled_extract_knowledge_graph = optimizer.compile(ExtractKnowledgeGraph(lm = gpt4), trainset = trainset)

View file

@ -41,7 +41,7 @@ def are_all_nodes_connected(graph: KnowledgeGraph) -> bool:
class ExtractKnowledgeGraph(dspy.Module):
def __init__(self, lm = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)):
def __init__(self, lm = dspy.OpenAI(model = config.llm_model, api_key = config.llm_api_key, model_type = "chat", max_tokens = 4096)):
super().__init__()
self.lm = lm
dspy.settings.configure(lm=self.lm)
@ -50,7 +50,7 @@ class ExtractKnowledgeGraph(dspy.Module):
def forward(self, context: str, question: str):
context = remove_stop_words(context)
context = trim_text_to_max_tokens(context, 1500, config.openai_model)
context = trim_text_to_max_tokens(context, 1500, config.llm_model)
with dspy.context(lm = self.lm):
graph = self.generate_graph(text = context).graph
@ -79,7 +79,7 @@ def remove_stop_words(text):
#
# if __name__ == "__main__":
# gpt_4_turbo = dspy.OpenAI(model="gpt-4", max_tokens=4000, api_key=config.openai_key, model_type="chat")
# gpt_4_turbo = dspy.OpenAI(model="gpt-4", max_tokens=4000, api_key=config.llm_api_key, model_type="chat")
# dspy.settings.configure(lm=gpt_4_turbo)

View file

@ -4,7 +4,7 @@
from typing import Union, Dict
import networkx as nx
from cognee.shared.data_models import GraphDBType
async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]:
async def search_adjacent(graph: Union[nx.Graph, any], query: str, other_param: dict = None) -> Dict[str, str]:
"""
Find the neighbours of a given node in the graph and return their descriptions.
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.
@ -12,13 +12,12 @@ async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructur
Parameters:
- graph (Union[nx.Graph, AsyncSession]): The graph object or Neo4j session.
- query (str): Used as a fallback node id when 'other_param' does not provide one.
- infrastructure_config (Dict): Configuration that includes the graph engine type.
- other_param (dict, optional): A dictionary that may contain 'node_id' to specify the node.
Returns:
- Dict[str, str]: A dictionary containing the unique identifiers and descriptions of the neighbours of the given node.
"""
node_id = other_param.get('node_id') if other_param else None
node_id = other_param.get('node_id') if other_param else query
if node_id is None:
return {}

View file

@ -7,7 +7,7 @@ from cognee.infrastructure.databases.graph.get_graph_client import get_graph_cli
import networkx as nx
from cognee.shared.data_models import GraphDBType
async def search_neighbour(graph: Union[nx.Graph, any], node_id: str,
async def search_neighbour(graph: Union[nx.Graph, any], query: str,
other_param: dict = None):
"""
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.
@ -23,8 +23,7 @@ async def search_neighbour(graph: Union[nx.Graph, any], node_id: str,
- List[str]: A list of 'description' attributes of nodes that share the same 'layer_uuid' with the specified node.
"""
from cognee.infrastructure import infrastructure_config
if node_id is None:
node_id = other_param.get('node_id') if other_param else None
node_id = other_param.get('node_id') if other_param else query
if node_id is None:
return []

View file

@ -1,41 +1,84 @@
from cognee.config import Config
from cognee.infrastructure import infrastructure_config
config = Config()
config.load()
from cognee.infrastructure.llm import llm_config
def get_settings():
vector_engine_choice = infrastructure_config.get_config()["vector_engine_choice"]
vector_db_options = [{
"value": "weaviate",
"label": "Weaviate",
config = Config()
config.load()
vector_dbs = [{
"value": "weaviate",
"label": "Weaviate",
}, {
"value": "qdrant",
"label": "Qdrant",
"value": "qdrant",
"label": "Qdrant",
}, {
"value": "lancedb",
"label": "LanceDB",
"value": "lancedb",
"label": "LanceDB",
}]
vector_db_config = dict(
url = config.weaviate_url,
apiKey = config.weaviate_api_key,
choice = vector_db_options[0],
options = vector_db_options,
) if vector_engine_choice == "weaviate" else dict(
url = config.qdrant_url,
apiKey = config.qdrant_api_key,
choice = vector_db_options[1],
options = vector_db_options,
) if vector_engine_choice == "qdrant" else dict(
url = infrastructure_config.get_config("lance_db_path"),
choice = vector_db_options[2],
options = vector_db_options,
)
vector_engine = infrastructure_config.get_config("vector_engine")
llm_providers = [{
"value": "openai",
"label": "OpenAI",
}, {
"value": "ollama",
"label": "Ollama",
}, {
"value": "anthropic",
"label": "Anthropic",
}]
return dict(
llm = dict(
openAIApiKey = config.openai_key[:-10] + "**********",
),
vectorDB = vector_db_config,
llm = {
"provider": {
"label": llm_config.llm_provider,
"value": llm_config.llm_provider,
} if llm_config.llm_provider else llm_providers[0],
"model": {
"value": llm_config.llm_model,
"label": llm_config.llm_model,
} if llm_config.llm_model else None,
"apiKey": llm_config.llm_api_key[:-10] + "**********" if llm_config.llm_api_key else None,
"providers": llm_providers,
"models": {
"openai": [{
"value": "gpt-4o",
"label": "gpt-4o",
}, {
"value": "gpt-4-turbo",
"label": "gpt-4-turbo",
}, {
"value": "gpt-3.5-turbo",
"label": "gpt-3.5-turbo",
}],
"ollama": [{
"value": "llama3",
"label": "llama3",
}, {
"value": "mistral",
"label": "mistral",
}],
"anthropic": [{
"value": "Claude 3 Opus",
"label": "Claude 3 Opus",
}, {
"value": "Claude 3 Sonnet",
"label": "Claude 3 Sonnet",
}, {
"value": "Claude 3 Haiku",
"label": "Claude 3 Haiku",
}]
},
},
vectorDB = {
"provider": {
"label": vector_engine.name,
"value": vector_engine.name.lower(),
},
"url": vector_engine.url,
"apiKey": vector_engine.api_key,
"options": vector_dbs,
},
)
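
For orientation, a short sketch of consuming the reshaped GET /settings response; field names follow the dict built above (provider and model as value/label pairs, a masked apiKey, dropdown option lists), and the printed values are illustrative.

import requests

settings = requests.get("http://0.0.0.0:8000/settings").json()

# LLM section: feeds the provider and model dropdowns in SettingsModal.
print(settings["llm"]["provider"])          # e.g. {"label": "openai", "value": "openai"}
print(settings["llm"]["models"]["openai"])  # gpt-4o, gpt-4-turbo, gpt-3.5-turbo options

# Vector DB section: reflects the currently active vector engine instance.
print(settings["vectorDB"]["provider"])     # e.g. {"label": "LanceDB", "value": "lancedb"}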

View file

@ -1,15 +1,20 @@
import os
import json
import logging
from pydantic import BaseModel
from cognee.config import Config
config = Config()
from cognee.infrastructure.llm import llm_config
from cognee.infrastructure import infrastructure_config
class LLMConfig(BaseModel):
openAIApiKey: str
apiKey: str
model: str
provider: str
async def save_llm_config(llm_config: LLMConfig):
if "*" in llm_config.openAIApiKey:
return
async def save_llm_config(new_llm_config: LLMConfig):
llm_config.llm_provider = new_llm_config.provider
llm_config.llm_model = new_llm_config.model
os.environ["OPENAI_API_KEY"] = llm_config.openAIApiKey
config.load()
if "*****" not in new_llm_config.apiKey and len(new_llm_config.apiKey.strip()) > 0:
llm_config.llm_api_key = new_llm_config.apiKey
logging.error(json.dumps(llm_config.to_dict()))
infrastructure_config.llm_engine = None
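
A small sketch of the masked-key behaviour save_llm_config now implements: round-tripping the masked key from the settings UI leaves the stored key untouched, while provider and model are always updated and the cached llm_engine is invalidated. The import path for the LLMConfig payload model is an assumption, since the module filename is not visible in this diff.

import asyncio

from cognee.infrastructure.llm import llm_config
from cognee.modules.settings import save_llm_config
from cognee.modules.settings.save_llm_config import LLMConfig  # assumed module path

llm_config.llm_api_key = "sk-original-key"

# The apiKey below contains "*****", so the stored key is kept as-is.
asyncio.run(save_llm_config(LLMConfig(provider="openai", model="gpt-4-turbo", apiKey="**********")))
assert llm_config.llm_api_key == "sk-original-key"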

View file

@ -7,24 +7,24 @@ from cognee.infrastructure import infrastructure_config
config = Config()
class VectorDBConfig(BaseModel):
choice: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
url: str
apiKey: str
provider: Union[Literal["lancedb"], Literal["qdrant"], Literal["weaviate"]]
async def save_vector_db_config(vector_db_config: VectorDBConfig):
if vector_db_config.choice == "weaviate":
if vector_db_config.provider == "weaviate":
os.environ["WEAVIATE_URL"] = vector_db_config.url
os.environ["WEAVIATE_API_KEY"] = vector_db_config.apiKey
remove_qdrant_config()
if vector_db_config.choice == "qdrant":
if vector_db_config.provider == "qdrant":
os.environ["QDRANT_URL"] = vector_db_config.url
os.environ["QDRANT_API_KEY"] = vector_db_config.apiKey
remove_weaviate_config()
if vector_db_config.choice == "lancedb":
if vector_db_config.provider == "lancedb":
remove_qdrant_config()
remove_weaviate_config()

View file

@ -43,6 +43,7 @@ services:
limits:
cpus: "4.0"
memory: 8GB
frontend:
container_name: frontend
build:
@ -55,7 +56,6 @@ services:
networks:
- cognee_backend
postgres:
image: postgres
container_name: postgres
@ -69,18 +69,7 @@ services:
- cognee_backend
ports:
- "5432:5432"
litellm:
build:
context: .
args:
target: runtime
image: ghcr.io/berriai/litellm:main-latest
ports:
- "4000:4000" # Map the container port to the host, change the host port if necessary
volumes:
- ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file
# You can change the port or number of workers as per your requirements or pass any new supported CLI augument. Make sure the port passed here matches with the container port defined above in `ports` value
command: [ "--config", "/app/config.yaml", "--port", "4000", "--num_workers", "8" ]
falkordb:
image: falkordb/falkordb:edge
container_name: falkordb

8521
poetry.lock generated Normal file

File diff suppressed because it is too large

View file

@ -60,7 +60,6 @@ tiktoken = "^0.6.0"
dspy-ai = "2.4.3"
posthog = "^3.5.0"
lancedb = "^0.6.10"
importlib-metadata = "6.8.0"
litellm = "^1.37.3"
groq = "^0.5.0"