Add pydantic settings checker (#497)

<!-- .github/pull_request_template.md -->

## Description
Add a connection test for the embedding and LLM models at the start of cognee use.
Fix an issue with async use of the relational database.
Refactor the caching mechanism for all databases so that changes in config are
reflected in the engine `get` functions; a minimal sketch of the pattern follows.
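
A minimal sketch of that pattern (illustrative names, not the actual cognee API): the factory caches on the config values themselves, while the getter stays uncached and re-reads the settings on every call, so an edited config produces a cache miss rather than a stale engine.

```python
from functools import lru_cache

CONFIG = {"db_url": "http://localhost", "db_port": "1234"}  # stand-in for settings


@lru_cache
def create_engine(db_url: str, db_port: str):
    # Cached per unique (db_url, db_port); a changed value is a cache miss.
    return {"db_url": db_url, "db_port": db_port}


def get_engine():
    # Deliberately uncached: re-reads the config each call, so edits to
    # CONFIG are reflected while identical values still hit the cache.
    return create_engine(CONFIG["db_url"], CONFIG["db_port"])
```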

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Introduced connection testing for language and embedding services at
startup, ensuring improved reliability during data addition.
  
- **Refactor**
- Streamlined engine initialization across multiple database systems to
enhance performance and clarity.
- Improved parameter handling and caching strategies for faster, more
consistent operations.
  - Updated record identifiers for more robust and unique data storage.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: holchan <61059652+holchan@users.noreply.github.com>
Co-authored-by: Boris <boris@topoteretes.com>
Igor Ilic 2025-02-04 23:18:27 +01:00 committed by GitHub
parent 690d028928
commit df163b0431
10 changed files with 107 additions and 59 deletions

View file

@@ -16,9 +16,22 @@ async def add(
dataset_name: str = "main_dataset",
user: User = None,
):
# Create tables for databases
await create_relational_db_and_tables()
await create_pgvector_db_and_tables()
# Initialize first_run attribute if it doesn't exist
if not hasattr(add, "first_run"):
add.first_run = True
if add.first_run:
from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection
# Test LLM and Embedding configuration once before running Cognee
await test_llm_connection()
await test_embedding_connection()
add.first_run = False # Update flag after first run
if user is None:
user = await get_default_user()
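
The `first_run` guard above stores its state as an attribute on the function object itself. A standalone sketch of that run-once pattern (hypothetical function, not part of the diff):

```python
async def startup_checks_once():
    # Function attributes persist across calls, acting as a module-level
    # flag without declaring a global.
    if not hasattr(startup_checks_once, "first_run"):
        startup_checks_once.first_run = True
    if startup_checks_once.first_run:
        print("running one-time connection tests")
        startup_checks_once.first_run = False
```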

View file

@@ -25,11 +25,24 @@ class GraphConfig(BaseSettings):
return {
"graph_filename": self.graph_filename,
"graph_database_provider": self.graph_database_provider,
"graph_file_path": self.graph_file_path,
"graph_database_url": self.graph_database_url,
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_file_path": self.graph_file_path,
"graph_model": self.graph_model,
"graph_topology": self.graph_topology,
"model_config": self.model_config,
}
def to_hashable_dict(self) -> dict:
return {
"graph_database_provider": self.graph_database_provider,
"graph_database_url": self.graph_database_url,
"graph_database_username": self.graph_database_username,
"graph_database_password": self.graph_database_password,
"graph_database_port": self.graph_database_port,
"graph_file_path": self.graph_file_path,
}
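
Why a dedicated hashable dict: `functools.lru_cache` hashes the call arguments, and a plain dict argument is unhashable. Splatting scalar settings as keywords sidesteps that; a tiny illustration with made-up names:

```python
from functools import lru_cache


@lru_cache
def make_engine(provider: str, url: str):
    return (provider, url)  # stand-in for a real adapter


settings = {"provider": "neo4j", "url": "bolt://localhost:7687"}
make_engine(**settings)  # fine: hashable scalar values form the cache key
# make_engine(settings)  # TypeError: unhashable type: 'dict'
```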

View file

@@ -8,12 +8,12 @@ from .graph_db_interface import GraphDBInterface
async def get_graph_engine() -> GraphDBInterface:
"""Factory function to get the appropriate graph client based on the graph type."""
graph_client = create_graph_engine()
config = get_graph_config()
graph_client = create_graph_engine(**get_graph_config().to_hashable_dict())
# Async functions can't be cached. After creating and caching the graph engine
# handle all necessary async operations for different graph types below.
config = get_graph_config()
# Handle loading of graph for NetworkX
if config.graph_database_provider.lower() == "networkx" and graph_client.graph is None:
await graph_client.load_graph_from_file()
@@ -22,28 +22,30 @@ async def get_graph_engine() -> GraphDBInterface:
@lru_cache
def create_graph_engine() -> GraphDBInterface:
def create_graph_engine(
graph_database_provider,
graph_database_url,
graph_database_username,
graph_database_password,
graph_database_port,
graph_file_path,
):
"""Factory function to create the appropriate graph client based on the graph type."""
config = get_graph_config()
if config.graph_database_provider == "neo4j":
if not (
config.graph_database_url
and config.graph_database_username
and config.graph_database_password
):
if graph_database_provider == "neo4j":
if not (graph_database_url and graph_database_username and graph_database_password):
raise EnvironmentError("Missing required Neo4j credentials.")
from .neo4j_driver.adapter import Neo4jAdapter
return Neo4jAdapter(
graph_database_url=config.graph_database_url,
graph_database_username=config.graph_database_username,
graph_database_password=config.graph_database_password,
graph_database_url=graph_database_url,
graph_database_username=graph_database_username,
graph_database_password=graph_database_password,
)
elif config.graph_database_provider == "falkordb":
if not (config.graph_database_url and config.graph_database_port):
elif graph_database_provider == "falkordb":
if not (graph_database_url and graph_database_port):
raise EnvironmentError("Missing required FalkorDB credentials.")
from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine
@@ -52,13 +54,13 @@ def create_graph_engine() -> GraphDBInterface:
embedding_engine = get_embedding_engine()
return FalkorDBAdapter(
database_url=config.graph_database_url,
database_port=config.graph_database_port,
database_url=graph_database_url,
database_port=graph_database_port,
embedding_engine=embedding_engine,
)
from .networkx.adapter import NetworkXAdapter
graph_client = NetworkXAdapter(filename=config.graph_file_path)
graph_client = NetworkXAdapter(filename=graph_file_path)
return graph_client
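
Under this refactor, repeated calls with unchanged settings should reuse one adapter while an edited setting builds a fresh one. A hypothetical check, assuming the names from this diff are in scope:

```python
engine_a = create_graph_engine(**get_graph_config().to_hashable_dict())
engine_b = create_graph_engine(**get_graph_config().to_hashable_dict())
assert engine_a is engine_b  # identical settings hit the same lru_cache entry
```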

View file

@@ -1,6 +1,8 @@
from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from functools import lru_cache
@lru_cache
def create_relational_engine(
db_path: str,
db_name: str,

View file

@@ -1,10 +1,7 @@
# from functools import lru_cache
from .config import get_relational_config
from .create_relational_engine import create_relational_engine
# @lru_cache
def get_relational_engine():
relational_config = get_relational_config()
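
Note the getter deliberately stays uncached (the `@lru_cache` is commented out): it re-reads the config on each call and delegates to the value-cached factory. A sketch of the assumed shape, since the hunk is truncated here:

```python
def get_relational_engine():
    relational_config = get_relational_config()  # fresh settings on every call
    # create_relational_engine is @lru_cache'd on these values, so unchanged
    # settings still return the cached engine.
    return create_relational_engine(
        db_path=relational_config.db_path,
        db_name=relational_config.db_name,
    )
```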

View file

@@ -1,49 +1,47 @@
from typing import Dict
from functools import lru_cache
class VectorConfig(Dict):
vector_db_url: str
vector_db_port: str
vector_db_key: str
vector_db_provider: str
def create_vector_engine(config: VectorConfig, embedding_engine):
if config["vector_db_provider"] == "weaviate":
@lru_cache
def create_vector_engine(
embedding_engine,
vector_db_url: str,
vector_db_port: str,
vector_db_key: str,
vector_db_provider: str,
):
if vector_db_provider == "weaviate":
from .weaviate_db import WeaviateAdapter
if not (config["vector_db_url"] and config["vector_db_key"]):
if not (vector_db_url and vector_db_key):
raise EnvironmentError("Missing requred Weaviate credentials!")
return WeaviateAdapter(
config["vector_db_url"], config["vector_db_key"], embedding_engine=embedding_engine
)
return WeaviateAdapter(vector_db_url, vector_db_key, embedding_engine=embedding_engine)
elif config["vector_db_provider"] == "qdrant":
if not (config["vector_db_url"] and config["vector_db_key"]):
elif vector_db_provider == "qdrant":
if not (vector_db_url and vector_db_key):
raise EnvironmentError("Missing requred Qdrant credentials!")
from .qdrant.QDrantAdapter import QDrantAdapter
return QDrantAdapter(
url=config["vector_db_url"],
api_key=config["vector_db_key"],
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
)
elif config["vector_db_provider"] == "milvus":
elif vector_db_provider == "milvus":
from .milvus.MilvusAdapter import MilvusAdapter
if not config["vector_db_url"]:
if not vector_db_url:
raise EnvironmentError("Missing required Milvus credentials!")
return MilvusAdapter(
url=config["vector_db_url"],
api_key=config["vector_db_key"],
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
)
elif config["vector_db_provider"] == "pgvector":
elif vector_db_provider == "pgvector":
from cognee.infrastructure.databases.relational import get_relational_config
# Get configuration for postgres database
@@ -65,19 +63,19 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
return PGVectorAdapter(
connection_string,
config["vector_db_key"],
vector_db_key,
embedding_engine,
)
elif config["vector_db_provider"] == "falkordb":
if not (config["vector_db_url"] and config["vector_db_port"]):
elif vector_db_provider == "falkordb":
if not (vector_db_url and vector_db_port):
raise EnvironmentError("Missing requred FalkorDB credentials!")
from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter
return FalkorDBAdapter(
database_url=config["vector_db_url"],
database_port=config["vector_db_port"],
database_url=vector_db_url,
database_port=vector_db_port,
embedding_engine=embedding_engine,
)
@@ -85,7 +83,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
from .lancedb.LanceDBAdapter import LanceDBAdapter
return LanceDBAdapter(
url=config["vector_db_url"],
api_key=config["vector_db_key"],
url=vector_db_url,
api_key=vector_db_key,
embedding_engine=embedding_engine,
)
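
One caveat worth noting: `embedding_engine` is now part of the `lru_cache` key, and plain objects hash by identity, so the cache is only effective if callers pass the same engine instance each time (which a cached `get_embedding_engine()` would provide). A hypothetical call, assuming the names from this diff are in scope:

```python
engine = create_vector_engine(
    get_embedding_engine(),  # hashes by object identity in the cache key
    vector_db_url="http://localhost:6333",
    vector_db_port="6333",
    vector_db_key="dummy-key",
    vector_db_provider="qdrant",
)
```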

View file

@@ -1,9 +1,7 @@
from .config import get_vectordb_config
from .embeddings import get_embedding_engine
from .create_vector_engine import create_vector_engine
from functools import lru_cache
@lru_cache
def get_vector_engine():
return create_vector_engine(get_vectordb_config().to_dict(), get_embedding_engine())
return create_vector_engine(get_embedding_engine(), **get_vectordb_config().to_dict())

View file

@@ -1,6 +1,6 @@
import asyncio
from typing import List, Optional, get_type_hints
from uuid import UUID
from uuid import UUID, uuid4
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
@@ -69,7 +69,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
@@ -103,7 +103,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
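
A standalone sketch of the new primary-key pattern (SQLAlchemy 2.0 maps a `uuid.UUID` annotation to its `Uuid` type automatically): a `uuid4` default stays unique across re-created tables and concurrent writers, where an autoincrementing integer would restart its sequence.

```python
from uuid import UUID, uuid4

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ExampleDataPoint(Base):
    __tablename__ = "example_collection"
    # UUID defaults avoid key collisions when the table is re-created or
    # written from several processes; integer sequences would not.
    primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
```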

View file

@@ -1,2 +1,4 @@
from .config import get_llm_config
from .utils import get_max_chunk_tokens
from .utils import test_llm_connection
from .utils import test_embedding_connection

View file

@@ -36,3 +36,26 @@ def get_model_max_tokens(model_name: str):
logger.info("Model not found in LiteLLM's model_cost.")
return max_tokens
async def test_llm_connection():
try:
llm_adapter = get_llm_client()
await llm_adapter.acreate_structured_output(
text_input="test",
system_prompt='Respond to me with the following string: "test"',
response_model=str,
)
except Exception as e:
logger.error(e)
logger.error("Connection to LLM could not be established.")
raise e
async def test_embedding_connection():
try:
await get_vector_engine().embedding_engine.embed_text("test")
except Exception as e:
logger.error(e)
logger.error("Connection to Embedding handler could not be established.")
raise e
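
A hypothetical preflight script using the two helpers above (the import path is taken from this diff; `get_llm_client` and `get_vector_engine` are resolved inside the helpers):

```python
import asyncio

from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection


async def preflight():
    await test_llm_connection()        # raises if structured output fails
    await test_embedding_connection()  # raises if a test embedding fails


asyncio.run(preflight())
```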