diff --git a/.github/workflows/vector_db_tests.yml b/.github/workflows/vector_db_tests.yml
index 65b70abe5..6e9e34493 100644
--- a/.github/workflows/vector_db_tests.yml
+++ b/.github/workflows/vector_db_tests.yml
@@ -103,6 +103,55 @@ jobs:
           EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
         run: uv run python ./cognee/tests/test_pgvector.py
 
+  run-pgvector-multi-user-tests:
+    name: PGVector Multi-User Tests
+    runs-on: ubuntu-22.04
+    if: ${{ inputs.databases == 'all' || contains(inputs.databases, 'postgres') }}
+    services:
+      postgres:
+        image: pgvector/pgvector:pg17
+        env:
+          POSTGRES_USER: cognee
+          POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
+          POSTGRES_DB: cognee_db
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          - 5432:5432
+    steps:
+      - name: Check out
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Cognee Setup
+        uses: ./.github/actions/cognee_setup
+        with:
+          python-version: ${{ inputs.python-version }}
+          extra-dependencies: "postgres"
+
+      - name: Run PGVector Permissions Tests
+        env:
+          ENV: 'dev'
+          ENABLE_BACKEND_ACCESS_CONTROL: 'true'
+          LLM_MODEL: ${{ secrets.LLM_MODEL }}
+          LLM_ENDPOINT: ${{ secrets.LLM_ENDPOINT }}
+          LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
+          LLM_API_VERSION: ${{ secrets.LLM_API_VERSION }}
+          EMBEDDING_MODEL: ${{ secrets.EMBEDDING_MODEL }}
+          EMBEDDING_ENDPOINT: ${{ secrets.EMBEDDING_ENDPOINT }}
+          EMBEDDING_API_KEY: ${{ secrets.EMBEDDING_API_KEY }}
+          EMBEDDING_API_VERSION: ${{ secrets.EMBEDDING_API_VERSION }}
+          VECTOR_DB_URL: 127.0.0.1
+          VECTOR_DB_PORT: 5432
+          VECTOR_DB_USERNAME: cognee
+          VECTOR_DB_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
+          VECTOR_DATASET_DATABASE_HANDLER: pgvector
+        run: uv run python ./cognee/tests/test_permissions.py
+
   run-lancedb-tests:
     name: LanceDB Tests
     runs-on: ubuntu-22.04
diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py
index 6417f34f7..420eed4a5 100644
--- a/cognee/context_global_variables.py
+++ b/cognee/context_global_variables.py
@@ -121,13 +121,16 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
     )
 
     # Set vector and graph database configuration based on dataset database information
-    # TODO: Add better handling of vector and graph config accross Cognee.
+    # TODO: Add better handling of vector and graph config across Cognee.
     # LRU_CACHE takes into account order of inputs, if order of inputs is changed it will be registered as a new DB adapter
     vector_config = {
         "vector_db_provider": dataset_database.vector_database_provider,
         "vector_db_url": dataset_database.vector_database_url,
         "vector_db_key": dataset_database.vector_database_key,
         "vector_db_name": dataset_database.vector_database_name,
+        "vector_db_port": dataset_database.vector_database_connection_info.get("port", ""),
+        "vector_db_username": dataset_database.vector_database_connection_info.get("username", ""),
+        "vector_db_password": dataset_database.vector_database_connection_info.get("password", ""),
     }
 
     graph_config = {
diff --git a/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py b/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py
index 225e9732e..1bac5c4ef 100644
--- a/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py
+++ b/cognee/infrastructure/databases/dataset_database_handler/supported_dataset_database_handlers.py
@@ -7,6 +7,9 @@ from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandle
 from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
     KuzuDatasetDatabaseHandler,
 )
+from cognee.infrastructure.databases.vector.pgvector.PGVectorDatasetDatabaseHandler import (
+    PGVectorDatasetDatabaseHandler,
+)
 
 supported_dataset_database_handlers = {
     "neo4j_aura_dev": {
@@ -14,5 +17,9 @@ supported_dataset_database_handlers = {
         "handler_provider": "neo4j",
     },
     "lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
+    "pgvector": {
+        "handler_instance": PGVectorDatasetDatabaseHandler,
+        "handler_provider": "pgvector",
+    },
     "kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
 }
diff --git a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py
index 37ceb170d..9e09e962d 100644
--- a/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py
+++ b/cognee/infrastructure/databases/relational/sqlalchemy/SqlAlchemyAdapter.py
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List
 from contextlib import asynccontextmanager
 from sqlalchemy.orm import joinedload
 from sqlalchemy.exc import NoResultFound
-from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect
+from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect, URL
 from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
 
 from cognee.modules.data.models.Data import Data
@@ -87,6 +87,27 @@ class SQLAlchemyAdapter:
             connect_args=final_connect_args,
         )
 
+        from cognee.context_global_variables import backend_access_control_enabled
+
+        if backend_access_control_enabled():
+            from cognee.infrastructure.databases.vector.config import get_vectordb_config
+
+            vector_config = get_vectordb_config()
+            if vector_config.vector_db_provider == "pgvector":
+                # Create a maintenance engine, used when creating new Postgres databases.
+                # The database named "postgres" should always exist; we need it because a SQLAlchemy
+                # engine cannot execute queries without first connecting to an existing database.
+                maintenance_db_name = "postgres"
+                maintenance_db_url = URL.create(
+                    "postgresql+asyncpg",
+                    username=vector_config.vector_db_username,
+                    password=vector_config.vector_db_password,
+                    host=vector_config.vector_db_host,
+                    port=int(vector_config.vector_db_port),
+                    database=maintenance_db_name,
+                )
+                self.maintenance_engine = create_async_engine(maintenance_db_url)
+
         self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
 
     async def push_to_s3(self) -> None:
@@ -517,9 +538,32 @@ class SQLAlchemyAdapter:
         if not await file_storage.file_exists(db_name):
             await file_storage.ensure_directory_exists()
 
-        async with self.engine.begin() as connection:
-            if len(Base.metadata.tables.keys()) > 0:
-                await connection.run_sync(Base.metadata.create_all)
+        from cognee.infrastructure.databases.relational.config import get_relational_config
+
+        relational_config = get_relational_config()
+
+        if self.engine.dialect.name == "sqlite" or (
+            self.engine.dialect.name == "postgresql"
+            and relational_config.db_provider == "postgres"
+            and self.engine.url.database == relational_config.db_name
+        ):
+            # The relational db already exists in SQLite or Postgres here; we only need to populate its tables.
+            async with self.engine.begin() as connection:
+                if len(Base.metadata.tables.keys()) > 0:
+                    await connection.run_sync(Base.metadata.create_all)
+            return
+
+        from cognee.context_global_variables import backend_access_control_enabled
+
+        if self.engine.dialect.name == "postgresql" and backend_access_control_enabled():
+            # Connect to the maintenance db in order to create the new database.
+            # CREATE DATABASE must run outside of a transaction block, so set the AUTOCOMMIT isolation level.
+            connection = await self.maintenance_engine.connect()
+            connection = await connection.execution_options(isolation_level="AUTOCOMMIT")
+            await connection.execute(text(f'CREATE DATABASE "{self.engine.url.database}";'))
+
+            # Clean up resources
+            await connection.close()
 
     async def delete_database(self):
         """
diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py
index 86b2a0fce..738ee7adb 100644
--- a/cognee/infrastructure/databases/vector/config.py
+++ b/cognee/infrastructure/databases/vector/config.py
@@ -29,6 +29,9 @@ class VectorConfig(BaseSettings):
     vector_db_key: str = ""
     vector_db_provider: str = "lancedb"
     vector_dataset_database_handler: str = "lancedb"
+    vector_db_username: str = ""
+    vector_db_password: str = ""
+    vector_db_host: str = ""
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
@@ -65,6 +68,9 @@ class VectorConfig(BaseSettings):
             "vector_db_key": self.vector_db_key,
             "vector_db_provider": self.vector_db_provider,
             "vector_dataset_database_handler": self.vector_dataset_database_handler,
+            "vector_db_username": self.vector_db_username,
+            "vector_db_password": self.vector_db_password,
+            "vector_db_host": self.vector_db_host,
         }
 
 
diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py
index cdf65514f..8982fa06d 100644
--- a/cognee/infrastructure/databases/vector/create_vector_engine.py
+++ b/cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -14,6 +14,9 @@ def create_vector_engine(
     vector_db_port: str = "",
     vector_db_key: str = "",
     vector_dataset_database_handler: str = "",
+    vector_db_username: str = "",
+    vector_db_password: str = "",
+    vector_db_host: str = "",
 ):
     """
     Wrapper function to call create vector engine with caching.
@@ -26,6 +29,9 @@ def create_vector_engine(
         vector_db_port,
         vector_db_key,
         vector_dataset_database_handler,
+        vector_db_username,
+        vector_db_password,
+        vector_db_host,
     )
 
 
@@ -34,9 +40,12 @@ def _create_vector_engine(
     vector_db_provider: str,
     vector_db_url: str,
     vector_db_name: str,
-    vector_db_port: str = "",
-    vector_db_key: str = "",
-    vector_dataset_database_handler: str = "",
+    vector_db_port: str,
+    vector_db_key: str,
+    vector_dataset_database_handler: str,
+    vector_db_username: str,
+    vector_db_password: str,
+    vector_db_host: str,
 ):
     """
     Create a vector database engine based on the specified provider.
@@ -77,27 +86,43 @@ def _create_vector_engine(
     )
 
     if vector_db_provider.lower() == "pgvector":
-        from cognee.infrastructure.databases.relational import get_relational_config
+        from cognee.context_global_variables import backend_access_control_enabled
 
-        # Get configuration for postgres database
-        relational_config = get_relational_config()
-        db_username = relational_config.db_username
-        db_password = relational_config.db_password
-        db_host = relational_config.db_host
-        db_port = relational_config.db_port
-        db_name = relational_config.db_name
+        if backend_access_control_enabled():
+            connection_string: str = (
+                f"postgresql+asyncpg://{vector_db_username}:{vector_db_password}"
+                f"@{vector_db_host}:{vector_db_port}/{vector_db_name}"
+            )
+        else:
+            if (
+                vector_db_port
+                and vector_db_username
+                and vector_db_password
+                and vector_db_host
+                and vector_db_name
+            ):
+                connection_string: str = (
+                    f"postgresql+asyncpg://{vector_db_username}:{vector_db_password}"
+                    f"@{vector_db_host}:{vector_db_port}/{vector_db_name}"
+                )
+            else:
+                from cognee.infrastructure.databases.relational import get_relational_config
 
-        if not (db_host and db_port and db_name and db_username and db_password):
-            raise EnvironmentError("Missing requred pgvector credentials!")
+                # Get configuration for postgres database
+                relational_config = get_relational_config()
+                db_username = relational_config.db_username
+                db_password = relational_config.db_password
+                db_host = relational_config.db_host
+                db_port = relational_config.db_port
+                db_name = relational_config.db_name
 
-        connection_string = URL.create(
-            "postgresql+asyncpg",
-            username=db_username,
-            password=db_password,
-            host=db_host,
-            port=int(db_port),
-            database=db_name,
-        )
+                if not (db_host and db_port and db_name and db_username and db_password):
+                    raise EnvironmentError("Missing required pgvector credentials!")
+
+                connection_string: str = (
+                    f"postgresql+asyncpg://{db_username}:{db_password}"
+                    f"@{db_host}:{db_port}/{db_name}"
+                )
 
     try:
         from .pgvector.PGVectorAdapter import PGVectorAdapter
diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py
new file mode 100644
index 000000000..c9da5b12a
--- /dev/null
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorDatasetDatabaseHandler.py
@@ -0,0 +1,81 @@
+from uuid import UUID
+from typing import Optional
+
+from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
+from cognee.modules.users.models import User
+from cognee.modules.users.models import DatasetDatabase
+from cognee.infrastructure.databases.vector import get_vectordb_config
+from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
+
+
+class PGVectorDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
+    """
+    Handler for interacting with
+    PGVector dataset databases.
+    """
+
+    @classmethod
+    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
+        vector_config = get_vectordb_config()
+
+        if vector_config.vector_db_provider != "pgvector":
+            raise ValueError(
+                "PGVectorDatasetDatabaseHandler can only be used with the PGVector vector database provider."
+            )
+
+        vector_db_name = f"{dataset_id}"
+
+        new_vector_config = {
+            "vector_database_provider": vector_config.vector_db_provider,
+            "vector_database_url": vector_config.vector_db_url,
+            "vector_database_name": vector_db_name,
+            "vector_database_connection_info": {
+                "port": vector_config.vector_db_port,
+                "host": vector_config.vector_db_host,
+            },
+            "vector_dataset_database_handler": "pgvector",
+        }
+
+        from .create_db_and_tables import create_pg_database
+
+        await create_pg_database(
+            {
+                "vector_db_provider": new_vector_config["vector_database_provider"],
+                "vector_db_url": new_vector_config["vector_database_url"],
+                "vector_db_name": new_vector_config["vector_database_name"],
+                "vector_db_port": new_vector_config["vector_database_connection_info"]["port"],
+                "vector_db_key": "",
+                "vector_db_username": vector_config.vector_db_username,
+                "vector_db_password": vector_config.vector_db_password,
+                "vector_db_host": new_vector_config["vector_database_connection_info"]["host"],
+                "vector_dataset_database_handler": "pgvector",
+            }
+        )
+
+        return new_vector_config
+
+    @classmethod
+    async def resolve_dataset_connection_info(
+        cls, dataset_database: DatasetDatabase
+    ) -> DatasetDatabase:
+        vector_config = get_vectordb_config()
+        # Note: For PGVector, the vector DB username/password come from configuration so they are never stored in the DB
+        dataset_database.vector_database_connection_info["vector_db_username"] = (
+            vector_config.vector_db_username
+        )
+        dataset_database.vector_database_connection_info["vector_db_password"] = (
+            vector_config.vector_db_password
+        )
+        return dataset_database
+
+    @classmethod
+    async def delete_dataset(cls, dataset_database: DatasetDatabase):
+        vector_config = get_vectordb_config()
+        vector_engine = create_vector_engine(
+            vector_db_provider=dataset_database.vector_database_provider,
+            vector_db_url=dataset_database.vector_database_url,
+            vector_db_name=dataset_database.vector_database_name,
+            vector_db_port=dataset_database.vector_database_connection_info["port"],
+            vector_db_host=dataset_database.vector_database_connection_info["host"],
+            vector_db_username=vector_config.vector_db_username,
+            vector_db_password=vector_config.vector_db_password,
+        )
+        await vector_engine.prune()
diff --git a/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py b/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
index c2477086d..a1538f98a 100644
--- a/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
+++ b/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
@@ -10,3 +10,34 @@ async def create_db_and_tables():
     if vector_config["vector_db_provider"] == "pgvector":
         async with vector_engine.engine.begin() as connection:
             await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+
+
+async def create_pg_database(vector_config):
+    """
+    Create the necessary Postgres database and install the PGVector extension on it.
+    This is defined separately because the creation needs the latest vector config,
+    which is not yet saved in the vector config context variable.
+
+    TODO: We can maybe merge this with create_db_and_tables(), but it seemed simpler to keep them separate for now.
+    """
+    from cognee.infrastructure.databases.relational.create_relational_engine import (
+        create_relational_engine,
+    )
+
+    from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
+    from sqlalchemy import text
+
+    pg_relational_engine = create_relational_engine(
+        db_path="",
+        db_host=vector_config["vector_db_host"],
+        db_name=vector_config["vector_db_name"],
+        db_port=vector_config["vector_db_port"],
+        db_username=vector_config["vector_db_username"],
+        db_password=vector_config["vector_db_password"],
+        db_provider="postgres",
+    )
+    await pg_relational_engine.create_database()
+
+    vector_engine = create_vector_engine(**vector_config)
+    async with vector_engine.engine.begin() as connection:
+        await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
diff --git a/cognee/modules/engine/operations/setup.py b/cognee/modules/engine/operations/setup.py
index 4992931f2..a1b9a3751 100644
--- a/cognee/modules/engine/operations/setup.py
+++ b/cognee/modules/engine/operations/setup.py
@@ -4,6 +4,7 @@ from cognee.infrastructure.databases.relational import (
 from cognee.infrastructure.databases.vector.pgvector import (
     create_db_and_tables as create_pgvector_db_and_tables,
 )
+from cognee.context_global_variables import backend_access_control_enabled
 
 
 async def setup():
@@ -14,7 +15,8 @@ async def setup():
     followed by creating a PGVector database and its tables.
     """
     await create_relational_db_and_tables()
-    await create_pgvector_db_and_tables()
+    if not backend_access_control_enabled():
+        await create_pgvector_db_and_tables()
 
 
 if __name__ == "__main__":
diff --git a/cognee/modules/pipelines/layers/setup_and_check_environment.py b/cognee/modules/pipelines/layers/setup_and_check_environment.py
index 55e58ed8a..345acee8b 100644
--- a/cognee/modules/pipelines/layers/setup_and_check_environment.py
+++ b/cognee/modules/pipelines/layers/setup_and_check_environment.py
@@ -2,6 +2,7 @@ import asyncio
 from cognee.context_global_variables import (
     graph_db_config as context_graph_db_config,
     vector_db_config as context_vector_db_config,
+    backend_access_control_enabled,
 )
 
 from cognee.infrastructure.databases.relational import (
@@ -26,7 +27,8 @@ async def setup_and_check_environment(
 
     # Create tables for databases
     await create_relational_db_and_tables()
-    await create_pgvector_db_and_tables()
+    if not backend_access_control_enabled():
+        await create_pgvector_db_and_tables()
 
     global _first_run_done
     async with _first_run_lock:
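Supplementary sketches (illustrative only, not part of the patch above):

The new CI job drives the added VectorConfig fields purely through environment variables. The snippet below is a minimal sketch of how a pydantic-settings class such as VectorConfig would pick them up and how an asyncpg DSN is then assembled, matching the f-string form used in create_vector_engine. The class name, the env values, and the hard-coded database name are placeholders, not part of cognee; the real config also carries additional fields (url, port, key, provider) not repeated here.

import os

from pydantic_settings import BaseSettings, SettingsConfigDict


class VectorSettingsSketch(BaseSettings):
    # Mirrors the fields this diff adds to VectorConfig; defaults are empty strings.
    vector_db_username: str = ""
    vector_db_password: str = ""
    vector_db_host: str = ""
    vector_db_port: str = ""
    vector_db_provider: str = "lancedb"
    vector_dataset_database_handler: str = "lancedb"

    model_config = SettingsConfigDict(env_file=".env", extra="allow")


if __name__ == "__main__":
    # Placeholder values, loosely modelled on the env block of the new CI job.
    os.environ.update(
        {
            "VECTOR_DB_USERNAME": "cognee",
            "VECTOR_DB_PASSWORD": "cognee",
            "VECTOR_DB_HOST": "127.0.0.1",
            "VECTOR_DB_PORT": "5432",
            "VECTOR_DB_PROVIDER": "pgvector",
            "VECTOR_DATASET_DATABASE_HANDLER": "pgvector",
        }
    )
    settings = VectorSettingsSketch()
    # Same connection-string shape as the access-control branch in create_vector_engine.
    dsn = (
        f"postgresql+asyncpg://{settings.vector_db_username}:{settings.vector_db_password}"
        f"@{settings.vector_db_host}:{settings.vector_db_port}/cognee_db"
    )
    print(dsn)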
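For the maintenance-engine addition in SqlAlchemyAdapter: below is a standalone sketch of the same pattern, assuming a reachable Postgres server and the asyncpg driver. The always-present "postgres" database is used only as a place to connect, and CREATE DATABASE is issued under AUTOCOMMIT because Postgres refuses it inside a transaction block. Function name and the credentials in __main__ are placeholders.

import asyncio

from sqlalchemy import URL, text
from sqlalchemy.ext.asyncio import create_async_engine


async def create_database_via_maintenance_engine(
    username: str, password: str, host: str, port: int, new_db_name: str
) -> None:
    # Connect to the "postgres" maintenance database first; an engine cannot
    # issue CREATE DATABASE without being connected to some existing database.
    maintenance_url = URL.create(
        "postgresql+asyncpg",
        username=username,
        password=password,
        host=host,
        port=port,
        database="postgres",
    )
    # CREATE DATABASE must run outside a transaction block, so force AUTOCOMMIT.
    engine = create_async_engine(maintenance_url, isolation_level="AUTOCOMMIT")
    async with engine.connect() as connection:
        await connection.execute(text(f'CREATE DATABASE "{new_db_name}";'))
    await engine.dispose()


if __name__ == "__main__":
    # Placeholder credentials; adjust to a local Postgres instance.
    asyncio.run(
        create_database_via_maintenance_engine("cognee", "cognee", "127.0.0.1", 5432, "dataset_db")
    )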
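Finally, a hedged sketch of the record shape PGVectorDatasetDatabaseHandler produces: the per-dataset Postgres database is named after the dataset UUID, only non-secret connection details (host and port) are persisted, and credentials are re-attached from configuration when the record is resolved. The dataclass is an illustrative stand-in for cognee's DatasetDatabase model, not the real ORM class.

from dataclasses import dataclass, field
from uuid import UUID, uuid4


@dataclass
class DatasetDatabaseStandIn:
    # Illustrative stand-in for the DatasetDatabase model used in the diff.
    vector_database_provider: str
    vector_database_url: str
    vector_database_name: str
    vector_database_connection_info: dict = field(default_factory=dict)


def build_record(dataset_id: UUID, host: str, port: str, url: str) -> DatasetDatabaseStandIn:
    # Mirrors create_dataset: database name equals the dataset UUID,
    # and credentials are deliberately left out of the stored connection info.
    return DatasetDatabaseStandIn(
        vector_database_provider="pgvector",
        vector_database_url=url,
        vector_database_name=f"{dataset_id}",
        vector_database_connection_info={"host": host, "port": port},
    )


def resolve_credentials(
    record: DatasetDatabaseStandIn, username: str, password: str
) -> DatasetDatabaseStandIn:
    # Mirrors resolve_dataset_connection_info: secrets come from configuration at read time.
    record.vector_database_connection_info["vector_db_username"] = username
    record.vector_database_connection_info["vector_db_password"] = password
    return record


if __name__ == "__main__":
    record = build_record(uuid4(), host="127.0.0.1", port="5432", url="127.0.0.1")
    record = resolve_credentials(record, username="cognee", password="cognee")
    print(record.vector_database_name, record.vector_database_connection_info)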