# Add pydantic settings checker (#497)
## Description

- Add a connection test for the embedding and LLM models at the beginning of cognee use
- Fix an issue with relational database async use
- Refactor the cache mechanism for all databases so that config changes are reflected in the `get_*` functions

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.

## Summary by CodeRabbit

- **New Features**
  - Introduced connection testing for language and embedding services at startup, ensuring improved reliability during data addition.
- **Refactor**
  - Streamlined engine initialization across multiple database systems to enhance performance and clarity.
  - Improved parameter handling and caching strategies for faster, more consistent operations.
  - Updated record identifiers for more robust and unique data storage.

---------

Co-authored-by: holchan <61059652+holchan@users.noreply.github.com>
Co-authored-by: Boris <boris@topoteretes.com>
Parent: 690d028928 · Commit: df163b0431

10 changed files with 107 additions and 59 deletions
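Before the per-file diffs, it helps to see the caching pattern the refactor applies everywhere: `functools.lru_cache` only accepts hashable arguments, so each `get_*` helper now flattens its settings into scalar keyword arguments for a cached `create_*` factory. A minimal sketch of that shape, with hypothetical names:

```python
from functools import lru_cache


@lru_cache
def create_engine(provider: str, url: str, port: int):
    # The tuple of scalar arguments is the cache key, so changed
    # config values miss the cache and build a fresh engine.
    return object()  # stand-in for a real database adapter


def get_engine(config: dict):
    # No caching here: the current config is read on every call.
    return create_engine(**config)


e1 = get_engine({"provider": "neo4j", "url": "bolt://localhost", "port": 7687})
e2 = get_engine({"provider": "neo4j", "url": "bolt://localhost", "port": 7687})
e3 = get_engine({"provider": "neo4j", "url": "bolt://localhost", "port": 7688})
assert e1 is e2      # identical settings reuse the cached engine
assert e1 is not e3  # a config change yields a new engine
```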
```diff
@@ -16,9 +16,22 @@ async def add(
     dataset_name: str = "main_dataset",
     user: User = None,
 ):
     # Create tables for databases
     await create_relational_db_and_tables()
     await create_pgvector_db_and_tables()
 
+    # Initialize first_run attribute if it doesn't exist
+    if not hasattr(add, "first_run"):
+        add.first_run = True
+
+    if add.first_run:
+        from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection
+
+        # Test LLM and Embedding configuration once before running Cognee
+        await test_llm_connection()
+        await test_embedding_connection()
+        add.first_run = False  # Update flag after first run
+
     if user is None:
         user = await get_default_user()
```
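The one-time check uses a function attribute as its guard: `add.first_run` lives on the function object itself, so the connection tests fire once per process without a module-level global. The same pattern in isolation (stand-in body, no cognee imports):

```python
import asyncio


async def add(data: str):
    # Function attributes persist for the life of the process,
    # giving a one-time guard without a module-level global.
    if not hasattr(add, "first_run"):
        add.first_run = True

    if add.first_run:
        print("running one-time connection checks")
        add.first_run = False

    print(f"adding {data!r}")


asyncio.run(add("first"))   # checks run
asyncio.run(add("second"))  # checks skipped
```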
```diff
@@ -25,11 +25,24 @@ class GraphConfig(BaseSettings):
         return {
             "graph_filename": self.graph_filename,
             "graph_database_provider": self.graph_database_provider,
-            "graph_file_path": self.graph_file_path,
+            "graph_database_url": self.graph_database_url,
+            "graph_database_username": self.graph_database_username,
+            "graph_database_password": self.graph_database_password,
+            "graph_database_port": self.graph_database_port,
+            "graph_file_path": self.graph_file_path,
             "graph_model": self.graph_model,
             "graph_topology": self.graph_topology,
             "model_config": self.model_config,
         }
 
+    def to_hashable_dict(self) -> dict:
+        return {
+            "graph_database_provider": self.graph_database_provider,
+            "graph_database_url": self.graph_database_url,
+            "graph_database_username": self.graph_database_username,
+            "graph_database_password": self.graph_database_password,
+            "graph_database_port": self.graph_database_port,
+            "graph_file_path": self.graph_file_path,
+        }
```
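The new `to_hashable_dict` deliberately returns only hashable scalar values, because `to_dict` also carries entries like `graph_model` and `model_config` that `lru_cache` cannot hash. A quick demonstration of the failure mode this avoids (illustrative function):

```python
from functools import lru_cache


@lru_cache
def build(**kwargs):
    # lru_cache builds its key from every argument value.
    return tuple(sorted(kwargs))


build(url="bolt://localhost", port=7687)    # str and int hash fine
try:
    build(model_config={"extra": "allow"})  # a dict value cannot be hashed
except TypeError as err:
    print(err)  # unhashable type: 'dict'
```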
```diff
@@ -8,12 +8,12 @@ from .graph_db_interface import GraphDBInterface
 async def get_graph_engine() -> GraphDBInterface:
     """Factory function to get the appropriate graph client based on the graph type."""
-    graph_client = create_graph_engine()
-    config = get_graph_config()
+    graph_client = create_graph_engine(**get_graph_config().to_hashable_dict())
+
+    # Async functions can't be cached. After creating and caching the graph engine
+    # handle all necessary async operations for different graph types below.
+    config = get_graph_config()
 
     # Handle loading of graph for NetworkX
     if config.graph_database_provider.lower() == "networkx" and graph_client.graph is None:
         await graph_client.load_graph_from_file()
```
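The comment in the diff points at a real limitation: `lru_cache` on a coroutine function caches the coroutine object rather than its result, and a coroutine can only be awaited once. That is why the cache sits on the synchronous factory while async follow-up work (like loading the NetworkX graph) happens afterwards. A sketch of the pitfall being avoided:

```python
import asyncio
from functools import lru_cache


@lru_cache
async def broken_factory():
    # The cache stores the coroutine object, not the awaited result.
    return "engine"


async def main():
    await broken_factory()      # consumes the cached coroutine
    try:
        await broken_factory()  # returns the same, already-awaited coroutine
    except RuntimeError as err:
        print(err)  # cannot reuse already awaited coroutine


asyncio.run(main())
```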
```diff
@@ -22,28 +22,30 @@ async def get_graph_engine() -> GraphDBInterface:
 
 @lru_cache
-def create_graph_engine() -> GraphDBInterface:
+def create_graph_engine(
+    graph_database_provider,
+    graph_database_url,
+    graph_database_username,
+    graph_database_password,
+    graph_database_port,
+    graph_file_path,
+):
     """Factory function to create the appropriate graph client based on the graph type."""
-    config = get_graph_config()
-
-    if config.graph_database_provider == "neo4j":
-        if not (
-            config.graph_database_url
-            and config.graph_database_username
-            and config.graph_database_password
-        ):
+    if graph_database_provider == "neo4j":
+        if not (graph_database_url and graph_database_username and graph_database_password):
             raise EnvironmentError("Missing required Neo4j credentials.")
 
         from .neo4j_driver.adapter import Neo4jAdapter
 
         return Neo4jAdapter(
-            graph_database_url=config.graph_database_url,
-            graph_database_username=config.graph_database_username,
-            graph_database_password=config.graph_database_password,
+            graph_database_url=graph_database_url,
+            graph_database_username=graph_database_username,
+            graph_database_password=graph_database_password,
         )
 
-    elif config.graph_database_provider == "falkordb":
-        if not (config.graph_database_url and config.graph_database_port):
+    elif graph_database_provider == "falkordb":
+        if not (graph_database_url and graph_database_port):
             raise EnvironmentError("Missing required FalkorDB credentials.")
 
         from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
```
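Note that the adapter imports stay inside their provider branches, so a driver package is only imported when its backend is actually selected; combined with `lru_cache`, the import cost is also paid once. The same factory shape using only the standard library (illustrative names):

```python
from functools import lru_cache


@lru_cache
def create_client(provider: str, url: str):
    if provider == "sqlite":
        import sqlite3  # deferred: only loaded when this branch runs

        return sqlite3.connect(url)

    raise EnvironmentError(f"Unsupported provider: {provider}")


client = create_client("sqlite", ":memory:")
assert client is create_client("sqlite", ":memory:")  # same args, cached client
```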
```diff
@@ -52,13 +54,13 @@ def create_graph_engine() -> GraphDBInterface:
         embedding_engine = get_embedding_engine()
 
         return FalkorDBAdapter(
-            database_url=config.graph_database_url,
-            database_port=config.graph_database_port,
+            database_url=graph_database_url,
+            database_port=graph_database_port,
             embedding_engine=embedding_engine,
         )
 
     from .networkx.adapter import NetworkXAdapter
 
-    graph_client = NetworkXAdapter(filename=config.graph_file_path)
+    graph_client = NetworkXAdapter(filename=graph_file_path)
 
     return graph_client
```
```diff
@@ -1,6 +1,8 @@
 from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
+from functools import lru_cache
 
 
+@lru_cache
 def create_relational_engine(
     db_path: str,
     db_name: str,
```
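A side benefit of moving `lru_cache` onto the parameterized factory is that the decorator's standard introspection helpers, `cache_info()` and `cache_clear()`, now apply to engines. For example (stand-in factory body):

```python
from functools import lru_cache


@lru_cache
def create_relational_engine(db_path: str, db_name: str):
    return (db_path, db_name)  # stand-in for a real SQLAlchemy engine


create_relational_engine("/data", "cognee.db")
create_relational_engine("/data", "cognee.db")
print(create_relational_engine.cache_info())  # hits=1, misses=1, currsize=1
create_relational_engine.cache_clear()        # evict every cached engine
```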
```diff
@@ -1,10 +1,7 @@
-from functools import lru_cache
+# from functools import lru_cache
 
 from .config import get_relational_config
 from .create_relational_engine import create_relational_engine
 
 
-@lru_cache
+# @lru_cache
 def get_relational_engine():
     relational_config = get_relational_config()
```
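The commented-out decorator documents the bug being fixed: caching a zero-argument getter pins whatever engine was built first, so later config changes are silently ignored. A small illustration of that failure (hypothetical names):

```python
from functools import lru_cache

settings = {"db_name": "cognee_db"}


@lru_cache
def get_engine_cached():
    # Zero arguments means a single cache slot: the first result wins forever.
    return settings["db_name"]


print(get_engine_cached())       # cognee_db
settings["db_name"] = "other_db"
print(get_engine_cached())       # still cognee_db: the config change is lost
```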
```diff
@@ -1,49 +1,47 @@
-from typing import Dict
+from functools import lru_cache
 
 
-class VectorConfig(Dict):
-    vector_db_url: str
-    vector_db_port: str
-    vector_db_key: str
-    vector_db_provider: str
-
-
-def create_vector_engine(config: VectorConfig, embedding_engine):
-    if config["vector_db_provider"] == "weaviate":
+@lru_cache
+def create_vector_engine(
+    embedding_engine,
+    vector_db_url: str,
+    vector_db_port: str,
+    vector_db_key: str,
+    vector_db_provider: str,
+):
+    if vector_db_provider == "weaviate":
         from .weaviate_db import WeaviateAdapter
 
-        if not (config["vector_db_url"] and config["vector_db_key"]):
+        if not (vector_db_url and vector_db_key):
             raise EnvironmentError("Missing required Weaviate credentials!")
 
-        return WeaviateAdapter(
-            config["vector_db_url"], config["vector_db_key"], embedding_engine=embedding_engine
-        )
+        return WeaviateAdapter(vector_db_url, vector_db_key, embedding_engine=embedding_engine)
 
-    elif config["vector_db_provider"] == "qdrant":
-        if not (config["vector_db_url"] and config["vector_db_key"]):
+    elif vector_db_provider == "qdrant":
+        if not (vector_db_url and vector_db_key):
             raise EnvironmentError("Missing required Qdrant credentials!")
 
         from .qdrant.QDrantAdapter import QDrantAdapter
 
         return QDrantAdapter(
-            url=config["vector_db_url"],
-            api_key=config["vector_db_key"],
+            url=vector_db_url,
+            api_key=vector_db_key,
             embedding_engine=embedding_engine,
         )
 
-    elif config["vector_db_provider"] == "milvus":
+    elif vector_db_provider == "milvus":
         from .milvus.MilvusAdapter import MilvusAdapter
 
-        if not config["vector_db_url"]:
+        if not vector_db_url:
             raise EnvironmentError("Missing required Milvus credentials!")
 
         return MilvusAdapter(
-            url=config["vector_db_url"],
-            api_key=config["vector_db_key"],
+            url=vector_db_url,
+            api_key=vector_db_key,
             embedding_engine=embedding_engine,
         )
 
-    elif config["vector_db_provider"] == "pgvector":
+    elif vector_db_provider == "pgvector":
         from cognee.infrastructure.databases.relational import get_relational_config
 
         # Get configuration for postgres database
```
```diff
@@ -65,19 +63,19 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
 
         return PGVectorAdapter(
             connection_string,
-            config["vector_db_key"],
+            vector_db_key,
             embedding_engine,
         )
 
-    elif config["vector_db_provider"] == "falkordb":
-        if not (config["vector_db_url"] and config["vector_db_port"]):
+    elif vector_db_provider == "falkordb":
+        if not (vector_db_url and vector_db_port):
             raise EnvironmentError("Missing required FalkorDB credentials!")
 
         from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
 
         return FalkorDBAdapter(
-            database_url=config["vector_db_url"],
-            database_port=config["vector_db_port"],
+            database_url=vector_db_url,
+            database_port=vector_db_port,
             embedding_engine=embedding_engine,
         )
```
```diff
@@ -85,7 +83,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
         from .lancedb.LanceDBAdapter import LanceDBAdapter
 
         return LanceDBAdapter(
-            url=config["vector_db_url"],
-            api_key=config["vector_db_key"],
+            url=vector_db_url,
+            api_key=vector_db_key,
             embedding_engine=embedding_engine,
         )
```
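Dropping the `VectorConfig` dict for explicit keyword parameters keeps call sites almost unchanged, since a plain dict still unpacks with `**`, while making every argument part of the cache key. A self-contained sketch of the new calling convention (dummy adapter and embedder):

```python
from functools import lru_cache


@lru_cache
def create_vector_engine(embedding_engine, vector_db_url: str, vector_db_port: str,
                         vector_db_key: str, vector_db_provider: str):
    return (vector_db_provider, vector_db_url)  # stand-in for a real adapter


class DummyEmbedder:
    # Plain objects are hashable by identity, so an embedding engine
    # instance can pass through lru_cache unchanged.
    pass


embedder = DummyEmbedder()
config = {
    "vector_db_url": "http://localhost:6333",
    "vector_db_port": "6333",
    "vector_db_key": "secret",
    "vector_db_provider": "qdrant",
}
engine = create_vector_engine(embedder, **config)
```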
```diff
@@ -1,9 +1,7 @@
 from .config import get_vectordb_config
 from .embeddings import get_embedding_engine
 from .create_vector_engine import create_vector_engine
-from functools import lru_cache
 
 
-@lru_cache
 def get_vector_engine():
-    return create_vector_engine(get_vectordb_config().to_dict(), get_embedding_engine())
+    return create_vector_engine(get_embedding_engine(), **get_vectordb_config().to_dict())
```
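One `lru_cache` subtlety makes the consistent call shape in `get_vector_engine` worth keeping: positional and keyword forms of the same argument hash to different cache keys, so mixing styles would duplicate engines. A demonstration:

```python
from functools import lru_cache


@lru_cache
def make(x):
    return object()


a = make(1)
b = make(x=1)             # same value, but a different cache key
print(a is b)             # False
print(make.cache_info())  # hits=0, misses=2
```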
```diff
@@ -1,6 +1,6 @@
 import asyncio
 from typing import List, Optional, get_type_hints
-from uuid import UUID
+from uuid import UUID, uuid4
 
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy import JSON, Column, Table, select, delete, MetaData
```
```diff
@@ -69,7 +69,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             __tablename__ = collection_name
             __table_args__ = {"extend_existing": True}
             # PGVector requires one column to be the primary key
-            primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+            primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
             id: Mapped[data_point_types["id"]]
             payload = Column(JSON)
             vector = Column(self.Vector(vector_size))
```
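Replacing the autoincrementing integer with a client-generated UUID means row identity no longer depends on database sequence state, which is safer for tables redefined with `extend_existing=True` and for keeping identifiers unique across stores. A standalone model using the same pattern (assumes SQLAlchemy 2.x, which maps `uuid.UUID` annotations to its `Uuid` type):

```python
from uuid import UUID, uuid4

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Record(Base):
    __tablename__ = "records"
    # uuid4 runs client-side at insert time; no database sequence involved.
    primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
    name: Mapped[str]


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Record(name="test"))
    session.commit()
```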
```diff
@@ -103,7 +103,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             __tablename__ = collection_name
             __table_args__ = {"extend_existing": True}
             # PGVector requires one column to be the primary key
-            primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+            primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
             id: Mapped[data_point_types["id"]]
             payload = Column(JSON)
             vector = Column(self.Vector(vector_size))
```
```diff
@@ -1,2 +1,4 @@
 from .config import get_llm_config
 from .utils import get_max_chunk_tokens
+from .utils import test_llm_connection
+from .utils import test_embedding_connection
```
```diff
@@ -36,3 +36,26 @@ def get_model_max_tokens(model_name: str):
     logger.info("Model not found in LiteLLM's model_cost.")
 
     return max_tokens
+
+
+async def test_llm_connection():
+    try:
+        llm_adapter = get_llm_client()
+        await llm_adapter.acreate_structured_output(
+            text_input="test",
+            system_prompt='Respond to me with the following string: "test"',
+            response_model=str,
+        )
+    except Exception as e:
+        logger.error(e)
+        logger.error("Connection to LLM could not be established.")
+        raise e
+
+
+async def test_embedding_connection():
+    try:
+        await get_vector_engine().embedding_engine.embed_text("test")
+    except Exception as e:
+        logger.error(e)
+        logger.error("Connection to Embedding handler could not be established.")
+        raise e
```
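Both helpers follow the same probe-and-reraise shape: issue the cheapest possible real request, log the failure, and propagate it so `add` aborts before any data is ingested. A sketch of how a caller might gate startup on them (the cognee import mirrors the one added in `add` above):

```python
import asyncio

from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection


async def startup_check():
    # Either call raises if its provider is misconfigured, so failures
    # surface before any pipeline work begins.
    await test_llm_connection()
    await test_embedding_connection()
    print("LLM and embedding endpoints are reachable")


asyncio.run(startup_check())
```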