From df163b043195cc830560dcd8dc803edf14cac2da Mon Sep 17 00:00:00 2001 From: Igor Ilic <30923996+dexters1@users.noreply.github.com> Date: Tue, 4 Feb 2025 23:18:27 +0100 Subject: [PATCH] Add pydantic settings checker (#497) ## Description Add test of embedding and LLM model at beginning of cognee use. Fix issue with relational database async use. Refactor handling of cache mechanism for all databases so changes in config can be reflected in get functions. ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. ## Summary by CodeRabbit - **New Features** - Introduced connection testing for language and embedding services at startup, ensuring improved reliability during data addition. - **Refactor** - Streamlined engine initialization across multiple database systems to enhance performance and clarity. - Improved parameter handling and caching strategies for faster, more consistent operations. - Updated record identifiers for more robust and unique data storage. 
--------- Co-authored-by: holchan <61059652+holchan@users.noreply.github.com> Co-authored-by: Boris --- cognee/api/v1/add/add_v2.py | 13 +++++ .../infrastructure/databases/graph/config.py | 15 ++++- .../databases/graph/get_graph_engine.py | 40 +++++++------ .../relational/create_relational_engine.py | 2 + .../relational/get_relational_engine.py | 3 - .../databases/vector/create_vector_engine.py | 58 +++++++++---------- .../databases/vector/get_vector_engine.py | 4 +- .../vector/pgvector/PGVectorAdapter.py | 6 +- cognee/infrastructure/llm/__init__.py | 2 + cognee/infrastructure/llm/utils.py | 23 ++++++++ 10 files changed, 107 insertions(+), 59 deletions(-) diff --git a/cognee/api/v1/add/add_v2.py b/cognee/api/v1/add/add_v2.py index dbdb93a59..bd633118d 100644 --- a/cognee/api/v1/add/add_v2.py +++ b/cognee/api/v1/add/add_v2.py @@ -16,9 +16,22 @@ async def add( dataset_name: str = "main_dataset", user: User = None, ): + # Create tables for databases await create_relational_db_and_tables() await create_pgvector_db_and_tables() + # Initialize first_run attribute if it doesn't exist + if not hasattr(add, "first_run"): + add.first_run = True + + if add.first_run: + from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection + + # Test LLM and Embedding configuration once before running Cognee + await test_llm_connection() + await test_embedding_connection() + add.first_run = False # Update flag after first run + if user is None: user = await get_default_user() diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py index b24a9e964..a17e551b1 100644 --- a/cognee/infrastructure/databases/graph/config.py +++ b/cognee/infrastructure/databases/graph/config.py @@ -25,11 +25,24 @@ class GraphConfig(BaseSettings): return { "graph_filename": self.graph_filename, "graph_database_provider": self.graph_database_provider, - "graph_file_path": self.graph_file_path, "graph_database_url": 
self.graph_database_url, "graph_database_username": self.graph_database_username, "graph_database_password": self.graph_database_password, "graph_database_port": self.graph_database_port, + "graph_file_path": self.graph_file_path, + "graph_model": self.graph_model, + "graph_topology": self.graph_topology, + "model_config": self.model_config, + } + + def to_hashable_dict(self) -> dict: + return { + "graph_database_provider": self.graph_database_provider, + "graph_database_url": self.graph_database_url, + "graph_database_username": self.graph_database_username, + "graph_database_password": self.graph_database_password, + "graph_database_port": self.graph_database_port, + "graph_file_path": self.graph_file_path, } diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 4660a610f..848057a14 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -8,12 +8,12 @@ from .graph_db_interface import GraphDBInterface async def get_graph_engine() -> GraphDBInterface: """Factory function to get the appropriate graph client based on the graph type.""" - graph_client = create_graph_engine() + config = get_graph_config() + + graph_client = create_graph_engine(**get_graph_config().to_hashable_dict()) # Async functions can't be cached. After creating and caching the graph engine # handle all necessary async operations for different graph types bellow. 
- config = get_graph_config() - # Handle loading of graph for NetworkX if config.graph_database_provider.lower() == "networkx" and graph_client.graph is None: await graph_client.load_graph_from_file() @@ -22,28 +22,30 @@ async def get_graph_engine() -> GraphDBInterface: @lru_cache -def create_graph_engine() -> GraphDBInterface: +def create_graph_engine( + graph_database_provider, + graph_database_url, + graph_database_username, + graph_database_password, + graph_database_port, + graph_file_path, +): """Factory function to create the appropriate graph client based on the graph type.""" - config = get_graph_config() - if config.graph_database_provider == "neo4j": - if not ( - config.graph_database_url - and config.graph_database_username - and config.graph_database_password - ): + if graph_database_provider == "neo4j": + if not (graph_database_url and graph_database_username and graph_database_password): raise EnvironmentError("Missing required Neo4j credentials.") from .neo4j_driver.adapter import Neo4jAdapter return Neo4jAdapter( - graph_database_url=config.graph_database_url, - graph_database_username=config.graph_database_username, - graph_database_password=config.graph_database_password, + graph_database_url=graph_database_url, + graph_database_username=graph_database_username, + graph_database_password=graph_database_password, ) - elif config.graph_database_provider == "falkordb": - if not (config.graph_database_url and config.graph_database_port): + elif graph_database_provider == "falkordb": + if not (graph_database_url and graph_database_port): raise EnvironmentError("Missing required FalkorDB credentials.") from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine @@ -52,13 +54,13 @@ def create_graph_engine() -> GraphDBInterface: embedding_engine = get_embedding_engine() return FalkorDBAdapter( - database_url=config.graph_database_url, - database_port=config.graph_database_port, + database_url=graph_database_url, + 
database_port=graph_database_port, embedding_engine=embedding_engine, ) from .networkx.adapter import NetworkXAdapter - graph_client = NetworkXAdapter(filename=config.graph_file_path) + graph_client = NetworkXAdapter(filename=graph_file_path) return graph_client diff --git a/cognee/infrastructure/databases/relational/create_relational_engine.py b/cognee/infrastructure/databases/relational/create_relational_engine.py index 13a1edc23..054428896 100644 --- a/cognee/infrastructure/databases/relational/create_relational_engine.py +++ b/cognee/infrastructure/databases/relational/create_relational_engine.py @@ -1,6 +1,8 @@ from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter +from functools import lru_cache +@lru_cache def create_relational_engine( db_path: str, db_name: str, diff --git a/cognee/infrastructure/databases/relational/get_relational_engine.py b/cognee/infrastructure/databases/relational/get_relational_engine.py index 44aa7213b..6024c7bd0 100644 --- a/cognee/infrastructure/databases/relational/get_relational_engine.py +++ b/cognee/infrastructure/databases/relational/get_relational_engine.py @@ -1,10 +1,7 @@ -# from functools import lru_cache - from .config import get_relational_config from .create_relational_engine import create_relational_engine -# @lru_cache def get_relational_engine(): relational_config = get_relational_config() diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index e61c272e1..34e08d156 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -1,49 +1,47 @@ -from typing import Dict +from functools import lru_cache -class VectorConfig(Dict): - vector_db_url: str - vector_db_port: str - vector_db_key: str - vector_db_provider: str - - -def create_vector_engine(config: VectorConfig, embedding_engine): - if config["vector_db_provider"] == "weaviate": 
+@lru_cache +def create_vector_engine( + embedding_engine, + vector_db_url: str, + vector_db_port: str, + vector_db_key: str, + vector_db_provider: str, +): + if vector_db_provider == "weaviate": from .weaviate_db import WeaviateAdapter - if not (config["vector_db_url"] and config["vector_db_key"]): + if not (vector_db_url and vector_db_key): raise EnvironmentError("Missing requred Weaviate credentials!") - return WeaviateAdapter( - config["vector_db_url"], config["vector_db_key"], embedding_engine=embedding_engine - ) + return WeaviateAdapter(vector_db_url, vector_db_key, embedding_engine=embedding_engine) - elif config["vector_db_provider"] == "qdrant": - if not (config["vector_db_url"] and config["vector_db_key"]): + elif vector_db_provider == "qdrant": + if not (vector_db_url and vector_db_key): raise EnvironmentError("Missing requred Qdrant credentials!") from .qdrant.QDrantAdapter import QDrantAdapter return QDrantAdapter( - url=config["vector_db_url"], - api_key=config["vector_db_key"], + url=vector_db_url, + api_key=vector_db_key, embedding_engine=embedding_engine, ) - elif config["vector_db_provider"] == "milvus": + elif vector_db_provider == "milvus": from .milvus.MilvusAdapter import MilvusAdapter - if not config["vector_db_url"]: + if not vector_db_url: raise EnvironmentError("Missing required Milvus credentials!") return MilvusAdapter( - url=config["vector_db_url"], - api_key=config["vector_db_key"], + url=vector_db_url, + api_key=vector_db_key, embedding_engine=embedding_engine, ) - elif config["vector_db_provider"] == "pgvector": + elif vector_db_provider == "pgvector": from cognee.infrastructure.databases.relational import get_relational_config # Get configuration for postgres database @@ -65,19 +63,19 @@ def create_vector_engine(config: VectorConfig, embedding_engine): return PGVectorAdapter( connection_string, - config["vector_db_key"], + vector_db_key, embedding_engine, ) - elif config["vector_db_provider"] == "falkordb": - if not 
(config["vector_db_url"] and config["vector_db_port"]): + elif vector_db_provider == "falkordb": + if not (vector_db_url and vector_db_port): raise EnvironmentError("Missing requred FalkorDB credentials!") from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter return FalkorDBAdapter( - database_url=config["vector_db_url"], - database_port=config["vector_db_port"], + database_url=vector_db_url, + database_port=vector_db_port, embedding_engine=embedding_engine, ) @@ -85,7 +83,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine): from .lancedb.LanceDBAdapter import LanceDBAdapter return LanceDBAdapter( - url=config["vector_db_url"], - api_key=config["vector_db_key"], + url=vector_db_url, + api_key=vector_db_key, embedding_engine=embedding_engine, ) diff --git a/cognee/infrastructure/databases/vector/get_vector_engine.py b/cognee/infrastructure/databases/vector/get_vector_engine.py index 4a3e81d1e..280a55eee 100644 --- a/cognee/infrastructure/databases/vector/get_vector_engine.py +++ b/cognee/infrastructure/databases/vector/get_vector_engine.py @@ -1,9 +1,7 @@ from .config import get_vectordb_config from .embeddings import get_embedding_engine from .create_vector_engine import create_vector_engine -from functools import lru_cache -@lru_cache def get_vector_engine(): - return create_vector_engine(get_vectordb_config().to_dict(), get_embedding_engine()) + return create_vector_engine(get_embedding_engine(), **get_vectordb_config().to_dict()) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 3700fd0fa..e6659d155 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -1,6 +1,6 @@ import asyncio from typing import List, Optional, get_type_hints -from uuid import UUID +from uuid import UUID, uuid4 from sqlalchemy.orm import Mapped, mapped_column from 
sqlalchemy import JSON, Column, Table, select, delete, MetaData @@ -69,7 +69,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): __tablename__ = collection_name __table_args__ = {"extend_existing": True} # PGVector requires one column to be the primary key - primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4) id: Mapped[data_point_types["id"]] payload = Column(JSON) vector = Column(self.Vector(vector_size)) @@ -103,7 +103,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): __tablename__ = collection_name __table_args__ = {"extend_existing": True} # PGVector requires one column to be the primary key - primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4) id: Mapped[data_point_types["id"]] payload = Column(JSON) vector = Column(self.Vector(vector_size)) diff --git a/cognee/infrastructure/llm/__init__.py b/cognee/infrastructure/llm/__init__.py index 36d7e56ad..b1609b524 100644 --- a/cognee/infrastructure/llm/__init__.py +++ b/cognee/infrastructure/llm/__init__.py @@ -1,2 +1,4 @@ from .config import get_llm_config from .utils import get_max_chunk_tokens +from .utils import test_llm_connection +from .utils import test_embedding_connection diff --git a/cognee/infrastructure/llm/utils.py b/cognee/infrastructure/llm/utils.py index e0aa8945a..d0479fb30 100644 --- a/cognee/infrastructure/llm/utils.py +++ b/cognee/infrastructure/llm/utils.py @@ -36,3 +36,26 @@ def get_model_max_tokens(model_name: str): logger.info("Model not found in LiteLLM's model_cost.") return max_tokens + + +async def test_llm_connection(): + try: + llm_adapter = get_llm_client() + await llm_adapter.acreate_structured_output( + text_input="test", + system_prompt='Respond to me with the following string: "test"', + response_model=str, + ) + except Exception as e: + 
logger.error(e) + logger.error("Connection to LLM could not be established.") + raise e + + +async def test_embedding_connection(): + try: + await get_vector_engine().embedding_engine.embed_text("test") + except Exception as e: + logger.error(e) + logger.error("Connection to Embedding handler could not be established.") + raise e