feat: COG-585 enable custom llm and embedding models
This commit is contained in:
parent a8aefd57ef
commit d1f8217320

11 changed files with 222 additions and 173 deletions
@@ -9,18 +9,21 @@ async def get_graph_engine() -> GraphDBInterface:
    config = get_graph_config()

    if config.graph_database_provider == "neo4j":
        try:
            from .neo4j_driver.adapter import Neo4jAdapter
        if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
            raise EnvironmentError("Missing required Neo4j credentials.")

        from .neo4j_driver.adapter import Neo4jAdapter

        return Neo4jAdapter(
            graph_database_url = config.graph_database_url,
            graph_database_username = config.graph_database_username,
            graph_database_password = config.graph_database_password
        )
        except:
            pass
        return Neo4jAdapter(
            graph_database_url = config.graph_database_url,
            graph_database_username = config.graph_database_username,
            graph_database_password = config.graph_database_password
        )

    elif config.graph_database_provider == "falkordb":
        if not (config.graph_database_url and config.graph_database_username and config.graph_database_password):
            raise EnvironmentError("Missing required FalkorDB credentials.")

        from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
        from cognee.infrastructure.databases.hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
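The change above makes get_graph_engine() fail fast with an EnvironmentError when Neo4j or FalkorDB credentials are missing, instead of swallowing errors in a bare try/except. A minimal sketch of how a caller sees this — the upper-case environment-variable names are assumed from the config field names (pydantic-settings' default mapping), the import path of get_graph_engine is assumed, and all values are placeholders:

```python
import os
import asyncio

# Assumed env-var names; set them before the cached graph config is first loaded.
os.environ["GRAPH_DATABASE_PROVIDER"] = "neo4j"
os.environ["GRAPH_DATABASE_URL"] = "bolt://localhost:7687"   # placeholder
os.environ["GRAPH_DATABASE_USERNAME"] = "neo4j"              # placeholder
os.environ["GRAPH_DATABASE_PASSWORD"] = "change-me"          # placeholder

from cognee.infrastructure.databases.graph import get_graph_engine  # path assumed

async def main():
    # With any of the three credentials missing, this now raises EnvironmentError
    # instead of silently falling through.
    graph_engine = await get_graph_engine()
    print(type(graph_engine).__name__)

asyncio.run(main())
```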
@@ -10,26 +10,29 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
    if config["vector_db_provider"] == "weaviate":
        from .weaviate_db import WeaviateAdapter

        if config["vector_db_url"] is None and config["vector_db_key"] is None:
            raise EnvironmentError("Weaviate is not configured!")
        if not (config["vector_db_url"] and config["vector_db_key"]):
            raise EnvironmentError("Missing requred Weaviate credentials!")

        return WeaviateAdapter(
            config["vector_db_url"],
            config["vector_db_key"],
            embedding_engine = embedding_engine
        )
    elif config["vector_db_provider"] == "qdrant":
        if config["vector_db_url"] and config["vector_db_key"]:
            from .qdrant.QDrantAdapter import QDrantAdapter

            return QDrantAdapter(
                url = config["vector_db_url"],
                api_key = config["vector_db_key"],
                embedding_engine = embedding_engine
            )
    elif config["vector_db_provider"] == "qdrant":
        if not (config["vector_db_url"] and config["vector_db_key"]):
            raise EnvironmentError("Missing requred Qdrant credentials!")

        from .qdrant.QDrantAdapter import QDrantAdapter

        return QDrantAdapter(
            url = config["vector_db_url"],
            api_key = config["vector_db_key"],
            embedding_engine = embedding_engine
        )

    elif config["vector_db_provider"] == "pgvector":
        from cognee.infrastructure.databases.relational import get_relational_config
        from .pgvector.PGVectorAdapter import PGVectorAdapter

        # Get configuration for postgres database
        relational_config = get_relational_config()
@@ -39,16 +42,25 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
        db_port = relational_config.db_port
        db_name = relational_config.db_name

        if not (db_host and db_port and db_name and db_username and db_password):
            raise EnvironmentError("Missing requred pgvector credentials!")

        connection_string: str = (
            f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
            f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
        )


        from .pgvector.PGVectorAdapter import PGVectorAdapter

        return PGVectorAdapter(
            connection_string,
            config["vector_db_key"],
            embedding_engine,
        )

    elif config["vector_db_provider"] == "falkordb":
        if not (config["vector_db_url"] and config["vector_db_key"]):
            raise EnvironmentError("Missing requred FalkorDB credentials!")

        from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter

        return FalkorDBAdapter(
@@ -56,6 +68,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
            database_port = config["vector_db_port"],
            embedding_engine = embedding_engine,
        )

    else:
        from .lancedb.LanceDBAdapter import LanceDBAdapter

@@ -64,5 +77,3 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
            api_key = config["vector_db_key"],
            embedding_engine = embedding_engine,
        )

    raise EnvironmentError(f"Vector provider not configured correctly: {config['vector_db_provider']}")
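Every provider branch in create_vector_engine now validates its credentials up front and raises EnvironmentError. A rough usage sketch — the factory's module path is assumed and the credentials are placeholders; the config keys themselves come straight from the diff above:

```python
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
# Module path assumed; the factory function itself is shown in the diff above.
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine

config = {
    "vector_db_provider": "qdrant",
    "vector_db_url": "https://qdrant.example.com",  # placeholder
    "vector_db_key": "qdrant-api-key",              # placeholder
    "vector_db_port": None,                         # only the falkordb branch reads this
}

# With an empty url or key this now raises
# EnvironmentError("Missing requred Qdrant credentials!") instead of returning None.
vector_engine = create_vector_engine(config, get_embedding_engine())
```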
@@ -1,32 +1,39 @@
import asyncio
from typing import List, Optional
import litellm
from litellm import aembedding
from cognee.infrastructure.databases.vector.embeddings.EmbeddingEngine import EmbeddingEngine

litellm.set_verbose = False

class LiteLLMEmbeddingEngine(EmbeddingEngine):
    api_key: str
    embedding_model: str
    embedding_dimensions: int
    endpoint: str
    api_version: str
    model: str
    dimensions: int

    def __init__(
        self,
        embedding_model: Optional[str] = "text-embedding-3-large",
        embedding_dimensions: Optional[int] = 3072,
        model: Optional[str] = "text-embedding-3-large",
        dimensions: Optional[int] = 3072,
        api_key: str = None,
        endpoint: str = None,
        api_version: str = None,
    ):
        self.api_key = api_key
        self.embedding_model = embedding_model
        self.embedding_dimensions = embedding_dimensions
        self.endpoint = endpoint
        self.api_version = api_version
        self.model = model
        self.dimensions = dimensions

    async def embed_text(self, text: List[str]) -> List[List[float]]:
        async def get_embedding(text_):
            response = await aembedding(
                self.embedding_model,
            response = await litellm.aembedding(
                self.model,
                input = text_,
                api_key = self.api_key
                api_key = self.api_key,
                api_base = self.endpoint,
                api_version = self.api_version
            )

            return response.data[0]["embedding"]
@@ -36,4 +43,4 @@ class LiteLLMEmbeddingEngine(EmbeddingEngine):
        return result

    def get_vector_size(self) -> int:
        return self.embedding_dimensions
        return self.dimensions
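With endpoint and api_version now forwarded to litellm.aembedding as api_base/api_version, the engine can target any litellm-supported embedding backend, not only OpenAI. A minimal sketch — the model name, endpoint, and key are placeholders:

```python
import asyncio
from cognee.infrastructure.databases.vector.embeddings.LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine

engine = LiteLLMEmbeddingEngine(
    model = "text-embedding-3-large",                # placeholder model name
    dimensions = 3072,
    api_key = "sk-placeholder",                      # placeholder
    endpoint = "https://my-gateway.example.com/v1",  # forwarded to litellm as api_base
    api_version = None,                              # only needed for Azure-style APIs
)

async def main():
    vectors = await engine.embed_text(["cognee", "custom embedding models"])
    print(len(vectors), engine.get_vector_size())    # 2 3072

asyncio.run(main())
```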
@@ -1,23 +1,16 @@
from typing import Optional
from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict

class EmbeddingConfig(BaseSettings):
    openai_embedding_model: str = "text-embedding-3-large"
    openai_embedding_dimensions: int = 3072
    litellm_embedding_model: str = "BAAI/bge-large-en-v1.5"
    litellm_embedding_dimensions: int = 1024
    # embedding_engine:object = DefaultEmbeddingEngine(embedding_model=litellm_embedding_model, embedding_dimensions=litellm_embedding_dimensions)
    embedding_model: Optional[str] = "text-embedding-3-large"
    embedding_dimensions: Optional[int] = 3072
    embedding_endpoint: Optional[str] = None
    embedding_api_key: Optional[str] = None
    embedding_api_version: Optional[str] = None

    model_config = SettingsConfigDict(env_file = ".env", extra = "allow")

    def to_dict(self) -> dict:
        return {
            "openai_embedding_model": self.openai_embedding_model,
            "openai_embedding_dimensions": self.openai_embedding_dimensions,
            "litellm_embedding_model": self.litellm_embedding_model,
            "litellm_embedding_dimensions": self.litellm_embedding_dimensions,
        }

@lru_cache
def get_embedding_config():
    return EmbeddingConfig()
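The new embedding_* settings are plain BaseSettings fields, so they can be supplied through the .env file or the process environment. A sketch of overriding them in-process — the upper-case variable names assume pydantic-settings' default field-to-env-var mapping, and the values are placeholders:

```python
import os

# Must run before get_embedding_config() is first called, since it is lru_cached.
os.environ["EMBEDDING_MODEL"] = "azure/text-embedding-3-large"                    # placeholder
os.environ["EMBEDDING_DIMENSIONS"] = "3072"
os.environ["EMBEDDING_ENDPOINT"] = "https://my-resource.openai.azure.com"         # placeholder
os.environ["EMBEDDING_API_KEY"] = "azure-key"                                     # placeholder
os.environ["EMBEDDING_API_VERSION"] = "2024-02-01"                                # placeholder

from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config

config = get_embedding_config()
print(config.embedding_model, config.embedding_dimensions)
```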
@@ -1,7 +1,17 @@
from cognee.infrastructure.llm import get_llm_config
from cognee.infrastructure.databases.vector.embeddings.config import get_embedding_config
from cognee.infrastructure.llm.config import get_llm_config
from .EmbeddingEngine import EmbeddingEngine
from .LiteLLMEmbeddingEngine import LiteLLMEmbeddingEngine

def get_embedding_engine() -> EmbeddingEngine:
    config = get_embedding_config()
    llm_config = get_llm_config()
    return LiteLLMEmbeddingEngine(api_key = llm_config.llm_api_key)

    return LiteLLMEmbeddingEngine(
        # If OpenAI API is used for embeddings, litellm needs only the api_key.
        api_key = config.embedding_api_key or llm_config.llm_api_key,
        endpoint = config.embedding_endpoint,
        api_version = config.embedding_api_version,
        model = config.embedding_model,
        dimensions = config.embedding_dimensions,
    )
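As the comment in the diff notes, when the same OpenAI account serves both chat and embeddings, only the LLM key needs to be set; the embedding key is just the first link in a fallback chain. A small sketch of that behaviour (env-var names assumed from the BaseSettings field names, key is a placeholder):

```python
import os

os.environ["LLM_API_KEY"] = "sk-placeholder"   # placeholder
os.environ.pop("EMBEDDING_API_KEY", None)      # no embedding-specific key set

from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine

engine = get_embedding_engine()
# With embedding_api_key unset, the engine falls back to the LLM key.
assert engine.api_key == os.environ["LLM_API_KEY"]
```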
@@ -7,6 +7,7 @@ class LLMConfig(BaseSettings):
    llm_model: str = "gpt-4o-mini"
    llm_endpoint: str = ""
    llm_api_key: Optional[str] = None
    llm_api_version: Optional[str] = None
    llm_temperature: float = 0.0
    llm_streaming: bool = False
    transcription_model: str = "whisper-1"
@@ -19,6 +20,7 @@ class LLMConfig(BaseSettings):
            "model": self.llm_model,
            "endpoint": self.llm_endpoint,
            "api_key": self.llm_api_key,
            "api_version": self.llm_api_version,
            "temperature": self.llm_temperature,
            "streaming": self.llm_streaming,
            "transcription_model": self.transcription_model
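llm_api_version now sits alongside llm_endpoint in LLMConfig, which is what Azure OpenAI-style deployments and self-hosted gateways need. A sketch of supplying these through the environment — variable names assume pydantic-settings' default mapping, the provider string is an assumption, and all values are placeholders:

```python
import os

os.environ["LLM_PROVIDER"] = "openai"                                  # assumed provider value
os.environ["LLM_MODEL"] = "azure/my-gpt-4o-deployment"                 # placeholder
os.environ["LLM_ENDPOINT"] = "https://my-resource.openai.azure.com"    # placeholder
os.environ["LLM_API_KEY"] = "azure-key"                                # placeholder
os.environ["LLM_API_VERSION"] = "2024-02-01"                           # placeholder

from cognee.infrastructure.llm import get_llm_config

llm_config = get_llm_config()
print(llm_config.to_dict()["endpoint"], llm_config.to_dict()["api_version"])
```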
@@ -20,21 +20,33 @@ def get_llm_client():
            raise ValueError("LLM API key is not set.")

        from .openai.adapter import OpenAIAdapter
        return OpenAIAdapter(api_key=llm_config.llm_api_key, model=llm_config.llm_model, transcription_model=llm_config.transcription_model, streaming=llm_config.llm_streaming)

        return OpenAIAdapter(
            api_key = llm_config.llm_api_key,
            endpoint = llm_config.llm_endpoint,
            api_version = llm_config.llm_api_version,
            model = llm_config.llm_model,
            transcription_model = llm_config.transcription_model,
            streaming = llm_config.llm_streaming,
        )

    elif provider == LLMProvider.OLLAMA:
        if llm_config.llm_api_key is None:
            raise ValueError("LLM API key is not set.")

        from .generic_llm_api.adapter import GenericAPIAdapter
        return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Ollama")

    elif provider == LLMProvider.ANTHROPIC:
        from .anthropic.adapter import AnthropicAdapter
        return AnthropicAdapter(llm_config.llm_model)

    elif provider == LLMProvider.CUSTOM:
        if llm_config.llm_api_key is None:
            raise ValueError("LLM API key is not set.")

        from .generic_llm_api.adapter import GenericAPIAdapter
        return GenericAPIAdapter(llm_config.llm_endpoint, llm_config.llm_api_key, llm_config.llm_model, "Custom")

    else:
        raise ValueError(f"Unsupported LLM provider: {provider}")
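The dispatch now includes a CUSTOM provider that reuses GenericAPIAdapter with the configured endpoint, key, and model. Roughly, selecting it looks like this — the string value behind LLMProvider.CUSTOM and the module path of get_llm_client are both assumptions, and the endpoint is a placeholder OpenAI-compatible server:

```python
import os

os.environ["LLM_PROVIDER"] = "custom"                    # assumed enum value
os.environ["LLM_ENDPOINT"] = "http://localhost:8000/v1"  # placeholder
os.environ["LLM_API_KEY"] = "not-a-real-key"             # placeholder (still required)
os.environ["LLM_MODEL"] = "my-local-model"               # placeholder

from cognee.infrastructure.llm.get_llm_client import get_llm_client  # path assumed

client = get_llm_client()
print(type(client).__name__)   # expected: GenericAPIAdapter
```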
@@ -1,174 +1,121 @@
import asyncio
import base64
import os
import base64
from pathlib import Path
from typing import List, Type
from typing import Type

import openai
import litellm
import instructor
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt

from cognee.base_config import get_base_config
from cognee.infrastructure.llm.llm_interface import LLMInterface
from cognee.infrastructure.llm.prompts import read_query_prompt
# from cognee.shared.data_models import MonitoringTool

class OpenAIAdapter(LLMInterface):
    name = "OpenAI"
    model: str
    api_key: str
    api_version: str

    """Adapter for OpenAI's GPT-3, GPT=4 API"""
    def __init__(self, api_key: str, model: str, transcription_model:str, streaming: bool = False):
        base_config = get_base_config()

        # if base_config.monitoring_tool == MonitoringTool.LANGFUSE:
        #     from langfuse.openai import AsyncOpenAI, OpenAI
        # elif base_config.monitoring_tool == MonitoringTool.LANGSMITH:
        #     from langsmith import wrappers
        #     from openai import AsyncOpenAI
        #     AsyncOpenAI = wrappers.wrap_openai(AsyncOpenAI())
        # else:
        from openai import AsyncOpenAI, OpenAI

        self.aclient = instructor.from_openai(AsyncOpenAI(api_key = api_key))
        self.client = instructor.from_openai(OpenAI(api_key = api_key))
        self.base_openai_client = OpenAI(api_key = api_key)
        self.transcription_model = "whisper-1"
    def __init__(
        self,
        api_key: str,
        endpoint: str,
        api_version: str,
        model: str,
        transcription_model: str,
        streaming: bool = False,
    ):
        self.aclient = instructor.from_litellm(litellm.acompletion)
        self.client = instructor.from_litellm(litellm.completion)
        self.transcription_model = transcription_model
        self.model = model
        self.api_key = api_key
        self.endpoint = endpoint
        self.api_version = api_version
        self.streaming = streaming
    @retry(stop = stop_after_attempt(5))
    def completions_with_backoff(self, **kwargs):
        """Wrapper around ChatCompletion.create w/ backoff"""
        return openai.chat.completions.create(**kwargs)

    @retry(stop = stop_after_attempt(5))
    async def acompletions_with_backoff(self,**kwargs):
        """Wrapper around ChatCompletion.acreate w/ backoff"""
        return await openai.chat.completions.acreate(**kwargs)

    @retry(stop = stop_after_attempt(5))
    async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
        """Wrapper around Embedding.acreate w/ backoff"""

        return await self.aclient.embeddings.create(input = input, model = model)

    async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
        """To get text embeddings, import/call this function
        It specifies defaults + handles rate-limiting + is async"""
        text = text.replace("\n", " ")
        response = await self.aclient.embeddings.create(input = text, model = model)
        embedding = response.data[0].embedding
        return embedding

    @retry(stop = stop_after_attempt(5))
    def create_embedding_with_backoff(self, **kwargs):
        """Wrapper around Embedding.create w/ backoff"""
        return openai.embeddings.create(**kwargs)

    def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
        """To get text embeddings, import/call this function
        It specifies defaults + handles rate-limiting
        :param text: str
        :param model: str
        """
        text = text.replace("\n", " ")
        response = self.create_embedding_with_backoff(input=[text], model=model)
        embedding = response.data[0].embedding
        return embedding

    async def async_get_batch_embeddings_with_backoff(self, texts: List[str], models: List[str]):
        """To get multiple text embeddings in parallel, import/call this function
        It specifies defaults + handles rate-limiting + is async"""
        # Collect all coroutines
        coroutines = (self.async_get_embedding_with_backoff(text, model)
                      for text, model in zip(texts, models))

        # Run the coroutines in parallel and gather the results
        embeddings = await asyncio.gather(*coroutines)

        return embeddings

    @retry(stop = stop_after_attempt(5))
    async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
        """Generate a response from a user query."""

        return await self.aclient.chat.completions.create(
            model = self.model,
            messages = [
                {
                    "role": "user",
                    "content": f"""Use the given format to
                    extract information from the following input: {text_input}. """,
                },
                {"role": "system", "content": system_prompt},
            ],
            messages = [{
                "role": "user",
                "content": f"""Use the given format to
                extract information from the following input: {text_input}. """,
            }, {
                "role": "system",
                "content": system_prompt,
            }],
            api_key = self.api_key,
            api_base = self.endpoint,
            api_version = self.api_version,
            response_model = response_model,
            max_retries = 5,
        )


    @retry(stop = stop_after_attempt(5))
    def create_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
        """Generate a response from a user query."""

        return self.client.chat.completions.create(
            model = self.model,
            messages = [
                {
                    "role": "user",
                    "content": f"""Use the given format to
                    extract information from the following input: {text_input}. """,
                },
                {"role": "system", "content": system_prompt},
            ],
            messages = [{
                "role": "user",
                "content": f"""Use the given format to
                extract information from the following input: {text_input}. """,
            }, {
                "role": "system",
                "content": system_prompt,
            }],
            api_key = self.api_key,
            api_base = self.endpoint,
            api_version = self.api_version,
            response_model = response_model,
            max_retries = 5,
        )

    @retry(stop = stop_after_attempt(5))
    def create_transcript(self, input):
        """Generate a audio transcript from a user query."""

        if not os.path.isfile(input):
            raise FileNotFoundError(f"The file {input} does not exist.")

        with open(input, 'rb') as audio_file:
            audio_data = audio_file.read()
        # with open(input, 'rb') as audio_file:
        #     audio_data = audio_file.read()



        transcription = self.base_openai_client.audio.transcriptions.create(
            model=self.transcription_model ,
            file=Path(input),
        )
        transcription = litellm.transcription(
            model = self.transcription_model,
            file = Path(input),
            max_retries = 5,
        )

        return transcription


    @retry(stop = stop_after_attempt(5))
    def transcribe_image(self, input) -> BaseModel:
        with open(input, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

        return self.base_openai_client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What’s in this image?"},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{encoded_image}",
                            },
                        },
                    ],
                }
            ],
            max_tokens=300,
        return litellm.completion(
            model = self.model,
            messages = [{
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "What’s in this image?",
                    }, {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{encoded_image}",
                        },
                    },
                ],
            }],
            max_tokens = 300,
            max_retries = 5,
        )

    def show_prompt(self, text_input: str, system_prompt: str) -> str:
        """Format and display the prompt for a user query."""
        if not text_input:
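The adapter now builds its instructor clients from litellm instead of the OpenAI SDK, so api_base and api_version flow through every structured-output call and any OpenAI-compatible endpoint can be used. A stripped-down sketch of the same pattern outside the class — model, key, and endpoint are placeholders:

```python
import litellm
import instructor
from pydantic import BaseModel

class ExtractedFacts(BaseModel):
    facts: list[str]

# instructor patches litellm's completion function so it can return pydantic models.
client = instructor.from_litellm(litellm.completion)

result = client.chat.completions.create(
    model = "gpt-4o-mini",                            # placeholder model
    messages = [
        {"role": "system", "content": "Extract facts from the user's text."},
        {"role": "user", "content": "Cognee now supports custom LLM endpoints."},
    ],
    api_key = "sk-placeholder",                       # placeholder
    api_base = "https://my-gateway.example.com/v1",   # placeholder custom endpoint
    api_version = None,                               # set for Azure-style deployments
    response_model = ExtractedFacts,
    max_retries = 5,
)
print(result.facts)
```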
@@ -1,5 +1,7 @@
import json
import inspect
import logging
from cognee.modules.settings import get_current_settings
from cognee.shared.utils import send_telemetry
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
@@ -157,7 +159,7 @@ async def run_tasks_base(tasks: list[Task], data = None, user: User = None):
        })
        raise error

async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pipeline"):
async def run_tasks_with_telemetry(tasks: list[Task], data, pipeline_name: str):
    user = await get_default_user()

    try:
@@ -185,3 +187,10 @@ async def run_tasks(tasks: [Task], data = None, pipeline_name: str = "default_pi
        })

        raise error

async def run_tasks(tasks: list[Task], data = None, pipeline_name: str = "default_pipeline"):
    config = get_current_settings()
    logger.debug("\nRunning pipeline with configuration:\n%s\n", json.dumps(config, indent = 1))

    async for result in run_tasks_with_telemetry(tasks, data, pipeline_name):
        yield result
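run_tasks is now a thin async generator that logs the current settings once and then delegates to run_tasks_with_telemetry. The delegation pattern itself, shown generically — the names here are illustrative stand-ins, not cognee APIs:

```python
import json
import logging

logger = logging.getLogger(__name__)

async def inner_pipeline(items):
    # Stand-in for run_tasks_with_telemetry: yields results one by one.
    for item in items:
        yield item * 2

async def outer_pipeline(items, settings: dict):
    # Stand-in for run_tasks: log configuration once, then forward every result unchanged.
    logger.debug("Running pipeline with configuration:\n%s", json.dumps(settings, indent = 1))
    async for result in inner_pipeline(items):
        yield result
```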
@@ -1,3 +1,4 @@
from .get_current_settings import get_current_settings
from .get_settings import get_settings, SettingsDict
from .save_llm_config import save_llm_config
from .save_vector_db_config import save_vector_db_config
cognee/modules/settings/get_current_settings.py (new file, 54 lines added)
@@ -0,0 +1,54 @@
from typing import TypedDict
from cognee.infrastructure.llm import get_llm_config
from cognee.infrastructure.databases.graph import get_graph_config
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.relational.config import get_relational_config

class LLMConfig(TypedDict):
    model: str
    provider: str

class VectorDBConfig(TypedDict):
    url: str
    provider: str

class GraphDBConfig(TypedDict):
    url: str
    provider: str

class RelationalConfig(TypedDict):
    url: str
    provider: str

class SettingsDict(TypedDict):
    llm: LLMConfig
    graph: GraphDBConfig
    vector: VectorDBConfig
    relational: RelationalConfig

def get_current_settings() -> SettingsDict:
    llm_config = get_llm_config()
    graph_config = get_graph_config()
    vector_config = get_vectordb_config()
    relational_config = get_relational_config()

    return dict(
        llm = {
            "provider": llm_config.llm_provider,
            "model": llm_config.llm_model,
        },
        graph = {
            "provider": graph_config.graph_database_provider,
            "url": graph_config.graph_database_url or graph_config.graph_file_path,
        },
        vector = {
            "provider": vector_config.vector_db_provider,
            "url": vector_config.vector_db_url,
        },
        relational = {
            "provider": relational_config.db_provider,
            "url": f"{relational_config.db_host}:{relational_config.db_port}" \
                if relational_config.db_host \
                else f"{relational_config.db_path}/{relational_config.db_name}",
        },
    )
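get_current_settings() gives a serialisable snapshot of which providers and endpoints are active, which is exactly what the new run_tasks logs at startup. A one-line usage sketch:

```python
import json
from cognee.modules.settings import get_current_settings

# Mirrors the debug log emitted at the start of run_tasks.
print(json.dumps(get_current_settings(), indent = 1))
```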