diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py index 6417f34f7..420eed4a5 100644 --- a/cognee/context_global_variables.py +++ b/cognee/context_global_variables.py @@ -121,13 +121,16 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_ ) # Set vector and graph database configuration based on dataset database information - # TODO: Add better handling of vector and graph config accross Cognee. + # TODO: Add better handling of vector and graph config across Cognee. # LRU_CACHE takes into account order of inputs, if order of inputs is changed it will be registered as a new DB adapter vector_config = { "vector_db_provider": dataset_database.vector_database_provider, "vector_db_url": dataset_database.vector_database_url, "vector_db_key": dataset_database.vector_database_key, "vector_db_name": dataset_database.vector_database_name, + "vector_db_port": dataset_database.vector_database_connection_info.get("port", ""), + "vector_db_username": dataset_database.vector_database_connection_info.get("username", ""), + "vector_db_password": dataset_database.vector_database_connection_info.get("password", ""), } graph_config = { diff --git a/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py b/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py index 225e9732e..1bac5c4ef 100644 --- a/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py +++ b/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py @@ -7,6 +7,9 @@ from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandle from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import ( KuzuDatasetDatabaseHandler, ) +from cognee.infrastructure.databases.vector.pgvector.PGVectorDatasetDatabaseHandler import ( + PGVectorDatasetDatabaseHandler, +) 
supported_dataset_database_handlers = { "neo4j_aura_dev": { @@ -14,5 +17,9 @@ supported_dataset_database_handlers = { "handler_provider": "neo4j", }, "lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"}, + "pgvector": { + "handler_instance": PGVectorDatasetDatabaseHandler, + "handler_provider": "pgvector", + }, "kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"}, } diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py index 37ceb170d..c416ab6b3 100644 --- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py +++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py @@ -8,7 +8,7 @@ from typing import AsyncGenerator, List from contextlib import asynccontextmanager from sqlalchemy.orm import joinedload from sqlalchemy.exc import NoResultFound -from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect +from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect, URL from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from cognee.modules.data.models.Data import Data @@ -87,6 +87,27 @@ class SQLAlchemyAdapter: connect_args=final_connect_args, ) + from cognee.context_global_variables import backend_access_control_enabled + + if backend_access_control_enabled(): + from cognee.infrastructure.databases.vector.config import get_vectordb_config + + vector_config = get_vectordb_config() + if vector_config.vector_db_provider == "pgvector": + # Create a maintenance engine, used when creating new postgres databases. + # Database named "postgres" should always exist. We need this since the SQLAlchemy + # engine cannot directly execute queries without first connecting to a database. 
+ maintenance_db_name = "postgres" + maintenance_db_url = URL.create( + "postgresql+asyncpg", + username=vector_config.vector_db_username, + password=vector_config.vector_db_password, + host=vector_config.vector_db_url, + port=int(vector_config.vector_db_port), + database=maintenance_db_name, + ) + self.maintenance_engine = create_async_engine(maintenance_db_url) + self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) async def push_to_s3(self) -> None: @@ -517,9 +538,32 @@ class SQLAlchemyAdapter: if not await file_storage.file_exists(db_name): await file_storage.ensure_directory_exists() - async with self.engine.begin() as connection: - if len(Base.metadata.tables.keys()) > 0: - await connection.run_sync(Base.metadata.create_all) + from cognee.infrastructure.databases.relational.config import get_relational_config + + relational_config = get_relational_config() + + if self.engine.dialect.name == "sqlite" or ( + self.engine.dialect.name == "postgresql" + and relational_config.db_provider == "postgres" + and self.engine.url.database == relational_config.db_name + ): + # In this case we already have a relational db created in sqlite or postgres, we just need to populate it + async with self.engine.begin() as connection: + if len(Base.metadata.tables.keys()) > 0: + await connection.run_sync(Base.metadata.create_all) + return + + from cognee.context_global_variables import backend_access_control_enabled + + if self.engine.dialect.name == "postgresql" and backend_access_control_enabled(): + # Connect to maintenance db in order to create new database + # Make sure to execute CREATE DATABASE outside of transaction block, and set AUTOCOMMIT isolation level + connection = await self.maintenance_engine.connect() + await connection.execution_options(isolation_level="AUTOCOMMIT") + await connection.execute(text(f'CREATE DATABASE "{self.engine.url.database}";')) + + # Clean up resources + await connection.close() async def delete_database(self): """ 
diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py index 86b2a0fce..8661401c4 100644 --- a/cognee/infrastructure/databases/vector/config.py +++ b/cognee/infrastructure/databases/vector/config.py @@ -29,6 +29,8 @@ class VectorConfig(BaseSettings): vector_db_key: str = "" vector_db_provider: str = "lancedb" vector_dataset_database_handler: str = "lancedb" + vector_db_username: str = "" + vector_db_password: str = "" model_config = SettingsConfigDict(env_file=".env", extra="allow") @@ -65,6 +67,8 @@ class VectorConfig(BaseSettings): "vector_db_key": self.vector_db_key, "vector_db_provider": self.vector_db_provider, "vector_dataset_database_handler": self.vector_dataset_database_handler, + "vector_db_username": self.vector_db_username, + "vector_db_password": self.vector_db_password, } diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py index 8a87f0339..36c6ef09e 100644 --- a/cognee/infrastructure/databases/vector/create_vector_engine.py +++ b/cognee/infrastructure/databases/vector/create_vector_engine.py @@ -15,6 +15,8 @@ def create_vector_engine( vector_db_port: str = "", vector_db_key: str = "", vector_dataset_database_handler: str = "", + vector_db_username: str = "", + vector_db_password: str = "", ): """ Create a vector database engine based on the specified provider. 
@@ -55,27 +57,43 @@ def create_vector_engine( ) if vector_db_provider.lower() == "pgvector": - from cognee.infrastructure.databases.relational import get_relational_config + from cognee.context_global_variables import backend_access_control_enabled - # Get configuration for postgres database - relational_config = get_relational_config() - db_username = relational_config.db_username - db_password = relational_config.db_password - db_host = relational_config.db_host - db_port = relational_config.db_port - db_name = relational_config.db_name + if backend_access_control_enabled(): + connection_string: str = ( + f"postgresql+asyncpg://{vector_db_username}:{vector_db_password}" + f"@{vector_db_url}:{vector_db_port}/{vector_db_name}" + ) + else: + if ( + vector_db_port + and vector_db_username + and vector_db_password + and vector_db_url + and vector_db_name + ): + connection_string: str = ( + f"postgresql+asyncpg://{vector_db_username}:{vector_db_password}" + f"@{vector_db_url}:{vector_db_port}/{vector_db_name}" + ) + else: + from cognee.infrastructure.databases.relational import get_relational_config - if not (db_host and db_port and db_name and db_username and db_password): - raise EnvironmentError("Missing requred pgvector credentials!") + # Get configuration for postgres database + relational_config = get_relational_config() + db_username = relational_config.db_username + db_password = relational_config.db_password + db_host = relational_config.db_host + db_port = relational_config.db_port + db_name = relational_config.db_name - connection_string = URL.create( - "postgresql+asyncpg", - username=db_username, - password=db_password, - host=db_host, - port=int(db_port), - database=db_name, - ) + if not (db_host and db_port and db_name and db_username and db_password): + raise EnvironmentError("Missing required pgvector credentials!") + + connection_string: str = ( + f"postgresql+asyncpg://{db_username}:{db_password}" + f"@{db_host}:{db_port}/{db_name}" + ) try: from 
from uuid import UUID
from typing import Optional

from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
from cognee.modules.users.models import User
from cognee.modules.users.models import DatasetDatabase
from cognee.infrastructure.databases.vector import get_vectordb_config
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface


class PGVectorDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with PGVector Dataset databases.

    Every dataset is mapped to its own Postgres database, named after the dataset id,
    on the server described by the global vector configuration.
    """

    @classmethod
    async def _create_pg_database(cls, vector_config):
        """
        Create the necessary Postgres database, and the PGVector extension on it.

        This is defined here because the creation needs the latest vector config,
        which is not yet saved in the vector config context variable here.

        Parameters:
        -----------

            - vector_config: Mapping of `create_vector_engine` keyword arguments. The keys
              read here are `vector_db_url`, `vector_db_name`, `vector_db_port`,
              `vector_db_username` and `vector_db_password`; the whole mapping is also
              forwarded unchanged to `create_vector_engine`.
        """
        # Kept as function-local imports, as in the original code.
        # NOTE(review): presumably this avoids an import cycle with the relational package - confirm.
        from cognee.infrastructure.databases.relational.create_relational_engine import (
            create_relational_engine,
        )
        from sqlalchemy import text

        # Step 1: build a relational engine that points at the (not yet existing) dataset
        # database on the vector server and ask it to create that database.
        pg_relational_engine = create_relational_engine(
            db_path="",
            db_host=vector_config["vector_db_url"],
            db_name=vector_config["vector_db_name"],
            db_port=vector_config["vector_db_port"],
            db_username=vector_config["vector_db_username"],
            db_password=vector_config["vector_db_password"],
            db_provider="postgres",
        )
        await pg_relational_engine.create_database()

        # Step 2: connect to the new database through the vector engine and enable pgvector.
        # `create_vector_engine` is already imported at module level, so it is not
        # re-imported here.
        vector_engine = create_vector_engine(**vector_config)
        async with vector_engine.engine.begin() as connection:
            await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a dedicated PGVector database for a dataset and describe how to reach it.

        Parameters:
        -----------

            - dataset_id: Id of the dataset; its string form becomes the database name.
            - user: Not used by this handler.

        Returns:
        --------

            - dict: Connection description with the keys `vector_database_provider`,
              `vector_database_url`, `vector_database_name`,
              `vector_database_connection_info` (port, username, password) and
              `vector_dataset_database_handler`.

        Raises:
        -------

            - ValueError: If the configured vector database provider is not "pgvector".
        """
        vector_config = get_vectordb_config()

        if vector_config.vector_db_provider != "pgvector":
            raise ValueError(
                "PGVectorDatasetDatabaseHandler can only be used with PGVector vector database provider."
            )

        # NOTE(review): a `dataset_id` of None yields a database literally named "None" -
        # confirm whether callers can ever pass None here.
        vector_db_name = f"{dataset_id}"

        # NOTE(review): the password is returned in plain text inside the connection info -
        # confirm how callers persist this dict.
        connection_info = {
            "port": vector_config.vector_db_port,
            "username": vector_config.vector_db_username,
            "password": vector_config.vector_db_password,
        }

        new_vector_config = {
            "vector_database_provider": vector_config.vector_db_provider,
            "vector_database_url": vector_config.vector_db_url,
            "vector_database_name": vector_db_name,
            "vector_database_connection_info": connection_info,
            "vector_dataset_database_handler": "pgvector",
        }

        # The key order below is kept exactly as before.
        # NOTE(review): presumably `create_vector_engine` is cached with a cache that is
        # sensitive to keyword order - confirm before reordering these keys.
        engine_kwargs = {
            "vector_db_provider": new_vector_config["vector_database_provider"],
            "vector_db_url": new_vector_config["vector_database_url"],
            "vector_db_name": vector_db_name,
            "vector_db_port": connection_info["port"],
            "vector_db_key": "",
            "vector_db_username": connection_info["username"],
            "vector_db_password": connection_info["password"],
            "vector_dataset_database_handler": "pgvector",
        }
        await cls._create_pg_database(engine_kwargs)

        return new_vector_config

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        """
        Prune the vector store that belongs to a dataset.

        A vector engine is built from the connection details stored on `dataset_database`
        and its `prune()` method is awaited.

        NOTE(review): only `prune()` is called; whether the dataset's Postgres database
        itself gets dropped is not visible from here - confirm.

        Parameters:
        -----------

            - dataset_database: Record holding the provider, url, name and connection info
              (port, username, password) of the dataset's vector database.

        Raises:
        -------

            - KeyError: If "port", "username" or "password" is missing from
              `vector_database_connection_info`.
        """
        connection_info = dataset_database.vector_database_connection_info

        vector_engine = create_vector_engine(
            vector_db_provider=dataset_database.vector_database_provider,
            vector_db_url=dataset_database.vector_database_url,
            vector_db_name=dataset_database.vector_database_name,
            vector_db_port=connection_info["port"],
            vector_db_username=connection_info["username"],
            vector_db_password=connection_info["password"],
        )
        await vector_engine.prune()
cognee.infrastructure.databases.vector.pgvector import ( create_db_and_tables as create_pgvector_db_and_tables, ) +from cognee.context_global_variables import backend_access_control_enabled async def setup(): @@ -14,4 +15,5 @@ async def setup(): followed by creating a PGVector database and its tables. """ await create_relational_db_and_tables() - await create_pgvector_db_and_tables() + if not backend_access_control_enabled(): + await create_pgvector_db_and_tables() diff --git a/cognee/modules/pipelines/layers/setup_and_check_environment.py b/cognee/modules/pipelines/layers/setup_and_check_environment.py index 55e58ed8a..345acee8b 100644 --- a/cognee/modules/pipelines/layers/setup_and_check_environment.py +++ b/cognee/modules/pipelines/layers/setup_and_check_environment.py @@ -2,6 +2,7 @@ import asyncio from cognee.context_global_variables import ( graph_db_config as context_graph_db_config, vector_db_config as context_vector_db_config, + backend_access_control_enabled, ) from cognee.infrastructure.databases.relational import ( @@ -26,7 +27,8 @@ async def setup_and_check_environment( # Create tables for databases await create_relational_db_and_tables() - await create_pgvector_db_and_tables() + if not backend_access_control_enabled(): + await create_pgvector_db_and_tables() global _first_run_done async with _first_run_lock: