diff --git a/cognee/api/v1/add/add_v2.py b/cognee/api/v1/add/add_v2.py index dbdb93a59..bd633118d 100644 --- a/cognee/api/v1/add/add_v2.py +++ b/cognee/api/v1/add/add_v2.py @@ -16,9 +16,22 @@ async def add( dataset_name: str = "main_dataset", user: User = None, ): + # Create tables for databases await create_relational_db_and_tables() await create_pgvector_db_and_tables() + # Initialize first_run attribute if it doesn't exist + if not hasattr(add, "first_run"): + add.first_run = True + + if add.first_run: + from cognee.infrastructure.llm.utils import test_llm_connection, test_embedding_connection + + # Test LLM and Embedding configuration once before running Cognee + await test_llm_connection() + await test_embedding_connection() + add.first_run = False # Update flag after first run + if user is None: user = await get_default_user() diff --git a/cognee/infrastructure/databases/graph/config.py b/cognee/infrastructure/databases/graph/config.py index b24a9e964..a17e551b1 100644 --- a/cognee/infrastructure/databases/graph/config.py +++ b/cognee/infrastructure/databases/graph/config.py @@ -25,11 +25,24 @@ class GraphConfig(BaseSettings): return { "graph_filename": self.graph_filename, "graph_database_provider": self.graph_database_provider, - "graph_file_path": self.graph_file_path, "graph_database_url": self.graph_database_url, "graph_database_username": self.graph_database_username, "graph_database_password": self.graph_database_password, "graph_database_port": self.graph_database_port, + "graph_file_path": self.graph_file_path, + "graph_model": self.graph_model, + "graph_topology": self.graph_topology, + "model_config": self.model_config, + } + + def to_hashable_dict(self) -> dict: + return { + "graph_database_provider": self.graph_database_provider, + "graph_database_url": self.graph_database_url, + "graph_database_username": self.graph_database_username, + "graph_database_password": self.graph_database_password, + "graph_database_port": self.graph_database_port, + "graph_file_path": self.graph_file_path, } diff --git a/cognee/infrastructure/databases/graph/get_graph_engine.py b/cognee/infrastructure/databases/graph/get_graph_engine.py index 4660a610f..848057a14 100644 --- a/cognee/infrastructure/databases/graph/get_graph_engine.py +++ b/cognee/infrastructure/databases/graph/get_graph_engine.py @@ -8,12 +8,12 @@ from .graph_db_interface import GraphDBInterface async def get_graph_engine() -> GraphDBInterface: """Factory function to get the appropriate graph client based on the graph type.""" - graph_client = create_graph_engine() + config = get_graph_config() + + graph_client = create_graph_engine(**get_graph_config().to_hashable_dict()) # Async functions can't be cached. After creating and caching the graph engine # handle all necessary async operations for different graph types bellow. - config = get_graph_config() - # Handle loading of graph for NetworkX if config.graph_database_provider.lower() == "networkx" and graph_client.graph is None: await graph_client.load_graph_from_file() @@ -22,28 +22,30 @@ async def get_graph_engine() -> GraphDBInterface: @lru_cache -def create_graph_engine() -> GraphDBInterface: +def create_graph_engine( + graph_database_provider, + graph_database_url, + graph_database_username, + graph_database_password, + graph_database_port, + graph_file_path, +): """Factory function to create the appropriate graph client based on the graph type.""" - config = get_graph_config() - if config.graph_database_provider == "neo4j": - if not ( - config.graph_database_url - and config.graph_database_username - and config.graph_database_password - ): + if graph_database_provider == "neo4j": + if not (graph_database_url and graph_database_username and graph_database_password): raise EnvironmentError("Missing required Neo4j credentials.") from .neo4j_driver.adapter import Neo4jAdapter return Neo4jAdapter( - graph_database_url=config.graph_database_url, - graph_database_username=config.graph_database_username, - graph_database_password=config.graph_database_password, + graph_database_url=graph_database_url, + graph_database_username=graph_database_username, + graph_database_password=graph_database_password, ) - elif config.graph_database_provider == "falkordb": - if not (config.graph_database_url and config.graph_database_port): + elif graph_database_provider == "falkordb": + if not (graph_database_url and graph_database_port): raise EnvironmentError("Missing required FalkorDB credentials.") from cognee.infrastructure.databases.vector.embeddings import get_embedding_engine @@ -52,13 +54,13 @@ def create_graph_engine() -> GraphDBInterface: embedding_engine = get_embedding_engine() return FalkorDBAdapter( - database_url=config.graph_database_url, - database_port=config.graph_database_port, + database_url=graph_database_url, + database_port=graph_database_port, embedding_engine=embedding_engine, ) from .networkx.adapter import NetworkXAdapter - graph_client = NetworkXAdapter(filename=config.graph_file_path) + graph_client = NetworkXAdapter(filename=graph_file_path) return graph_client diff --git a/cognee/infrastructure/databases/relational/create_relational_engine.py b/cognee/infrastructure/databases/relational/create_relational_engine.py index 13a1edc23..054428896 100644 --- a/cognee/infrastructure/databases/relational/create_relational_engine.py +++ b/cognee/infrastructure/databases/relational/create_relational_engine.py @@ -1,6 +1,8 @@ from .sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter +from functools import lru_cache +@lru_cache def create_relational_engine( db_path: str, db_name: str, diff --git a/cognee/infrastructure/databases/relational/get_relational_engine.py b/cognee/infrastructure/databases/relational/get_relational_engine.py index 44aa7213b..6024c7bd0 100644 --- a/cognee/infrastructure/databases/relational/get_relational_engine.py +++ b/cognee/infrastructure/databases/relational/get_relational_engine.py @@ -1,10 +1,7 @@ -# from functools import lru_cache - from .config import get_relational_config from .create_relational_engine import create_relational_engine -# @lru_cache def get_relational_engine(): relational_config = get_relational_config() diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index e61c272e1..34e08d156 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -1,49 +1,47 @@ -from typing import Dict +from functools import lru_cache -class VectorConfig(Dict): - vector_db_url: str - vector_db_port: str - vector_db_key: str - vector_db_provider: str - - -def create_vector_engine(config: VectorConfig, embedding_engine): - if config["vector_db_provider"] == "weaviate": +@lru_cache +def create_vector_engine( + embedding_engine, + vector_db_url: str, + vector_db_port: str, + vector_db_key: str, + vector_db_provider: str, +): + if vector_db_provider == "weaviate": from .weaviate_db import WeaviateAdapter - if not (config["vector_db_url"] and config["vector_db_key"]): + if not (vector_db_url and vector_db_key): raise EnvironmentError("Missing requred Weaviate credentials!") - return WeaviateAdapter( - config["vector_db_url"], config["vector_db_key"], embedding_engine=embedding_engine - ) + return WeaviateAdapter(vector_db_url, vector_db_key, embedding_engine=embedding_engine) - elif config["vector_db_provider"] == "qdrant": - if not (config["vector_db_url"] and config["vector_db_key"]): + elif vector_db_provider == "qdrant": + if not (vector_db_url and vector_db_key): raise EnvironmentError("Missing requred Qdrant credentials!") from .qdrant.QDrantAdapter import QDrantAdapter return QDrantAdapter( - url=config["vector_db_url"], - api_key=config["vector_db_key"], + url=vector_db_url, + api_key=vector_db_key, embedding_engine=embedding_engine, ) - elif config["vector_db_provider"] == "milvus": + elif vector_db_provider == "milvus": from .milvus.MilvusAdapter import MilvusAdapter - if not config["vector_db_url"]: + if not vector_db_url: raise EnvironmentError("Missing required Milvus credentials!") return MilvusAdapter( - url=config["vector_db_url"], - api_key=config["vector_db_key"], + url=vector_db_url, + api_key=vector_db_key, embedding_engine=embedding_engine, ) - elif config["vector_db_provider"] == "pgvector": + elif vector_db_provider == "pgvector": from cognee.infrastructure.databases.relational import get_relational_config # Get configuration for postgres database @@ -65,19 +63,19 @@ def create_vector_engine(config: VectorConfig, embedding_engine): return PGVectorAdapter( connection_string, - config["vector_db_key"], + vector_db_key, embedding_engine, ) - elif config["vector_db_provider"] == "falkordb": - if not (config["vector_db_url"] and config["vector_db_port"]): + elif vector_db_provider == "falkordb": + if not (vector_db_url and vector_db_port): raise EnvironmentError("Missing requred FalkorDB credentials!") from ..hybrid.falkordb.FalkorDBAdapter import FalkorDBAdapter return FalkorDBAdapter( - database_url=config["vector_db_url"], - database_port=config["vector_db_port"], + database_url=vector_db_url, + database_port=vector_db_port, embedding_engine=embedding_engine, ) @@ -85,7 +83,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine): from .lancedb.LanceDBAdapter import LanceDBAdapter return LanceDBAdapter( - url=config["vector_db_url"], - api_key=config["vector_db_key"], + url=vector_db_url, + api_key=vector_db_key, embedding_engine=embedding_engine, ) diff --git a/cognee/infrastructure/databases/vector/get_vector_engine.py b/cognee/infrastructure/databases/vector/get_vector_engine.py index 4a3e81d1e..280a55eee 100644 --- a/cognee/infrastructure/databases/vector/get_vector_engine.py +++ b/cognee/infrastructure/databases/vector/get_vector_engine.py @@ -1,9 +1,7 @@ from .config import get_vectordb_config from .embeddings import get_embedding_engine from .create_vector_engine import create_vector_engine -from functools import lru_cache -@lru_cache def get_vector_engine(): - return create_vector_engine(get_vectordb_config().to_dict(), get_embedding_engine()) + return create_vector_engine(get_embedding_engine(), **get_vectordb_config().to_dict()) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 3700fd0fa..e6659d155 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -1,6 +1,6 @@ import asyncio from typing import List, Optional, get_type_hints -from uuid import UUID +from uuid import UUID, uuid4 from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy import JSON, Column, Table, select, delete, MetaData @@ -69,7 +69,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): __tablename__ = collection_name __table_args__ = {"extend_existing": True} # PGVector requires one column to be the primary key - primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4) id: Mapped[data_point_types["id"]] payload = Column(JSON) vector = Column(self.Vector(vector_size)) @@ -103,7 +103,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): __tablename__ = collection_name __table_args__ = {"extend_existing": True} # PGVector requires one column to be the primary key - primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4) id: Mapped[data_point_types["id"]] payload = Column(JSON) vector = Column(self.Vector(vector_size)) diff --git a/cognee/infrastructure/llm/__init__.py b/cognee/infrastructure/llm/__init__.py index 36d7e56ad..b1609b524 100644 --- a/cognee/infrastructure/llm/__init__.py +++ b/cognee/infrastructure/llm/__init__.py @@ -1,2 +1,4 @@ from .config import get_llm_config from .utils import get_max_chunk_tokens +from .utils import test_llm_connection +from .utils import test_embedding_connection diff --git a/cognee/infrastructure/llm/utils.py b/cognee/infrastructure/llm/utils.py index e0aa8945a..d0479fb30 100644 --- a/cognee/infrastructure/llm/utils.py +++ b/cognee/infrastructure/llm/utils.py @@ -36,3 +36,26 @@ def get_model_max_tokens(model_name: str): logger.info("Model not found in LiteLLM's model_cost.") return max_tokens + + +async def test_llm_connection(): + try: + llm_adapter = get_llm_client() + await llm_adapter.acreate_structured_output( + text_input="test", + system_prompt='Respond to me with the following string: "test"', + response_model=str, + ) + except Exception as e: + logger.error(e) + logger.error("Connection to LLM could not be established.") + raise e + + +async def test_embedding_connection(): + try: + await get_vector_engine().embedding_engine.embed_text("test") + except Exception as e: + logger.error(e) + logger.error("Connection to Embedding handler could not be established.") + raise e