feat: add multi user for PGVector, and enable the usage of PGVector with SQLite
This commit is contained in:
parent
ab990f7c5c
commit
67dfb37709
8 changed files with 202 additions and 25 deletions
|
|
@ -121,13 +121,16 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
|||
)
|
||||
|
||||
# Set vector and graph database configuration based on dataset database information
|
||||
# TODO: Add better handling of vector and graph config accross Cognee.
|
||||
# TODO: Add better handling of vector and graph config across Cognee.
|
||||
# LRU_CACHE takes into account order of inputs, if order of inputs is changed it will be registered as a new DB adapter
|
||||
vector_config = {
|
||||
"vector_db_provider": dataset_database.vector_database_provider,
|
||||
"vector_db_url": dataset_database.vector_database_url,
|
||||
"vector_db_key": dataset_database.vector_database_key,
|
||||
"vector_db_name": dataset_database.vector_database_name,
|
||||
"vector_db_port": dataset_database.vector_database_connection_info.get("port", ""),
|
||||
"vector_db_username": dataset_database.vector_database_connection_info.get("username", ""),
|
||||
"vector_db_password": dataset_database.vector_database_connection_info.get("password", ""),
|
||||
}
|
||||
|
||||
graph_config = {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@ from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandle
|
|||
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
|
||||
KuzuDatasetDatabaseHandler,
|
||||
)
|
||||
from cognee.infrastructure.databases.vector.pgvector.PGVectorDatasetDatabaseHandler import (
|
||||
PGVectorDatasetDatabaseHandler,
|
||||
)
|
||||
|
||||
supported_dataset_database_handlers = {
|
||||
"neo4j_aura_dev": {
|
||||
|
|
@ -14,5 +17,9 @@ supported_dataset_database_handlers = {
|
|||
"handler_provider": "neo4j",
|
||||
},
|
||||
"lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
|
||||
"pgvector": {
|
||||
"handler_instance": PGVectorDatasetDatabaseHandler,
|
||||
"handler_provider": "pgvector",
|
||||
},
|
||||
"kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import AsyncGenerator, List
|
|||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect
|
||||
from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect, URL
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
|
||||
from cognee.modules.data.models.Data import Data
|
||||
|
|
@ -87,6 +87,27 @@ class SQLAlchemyAdapter:
|
|||
connect_args=final_connect_args,
|
||||
)
|
||||
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
if backend_access_control_enabled():
|
||||
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
||||
|
||||
vector_config = get_vectordb_config()
|
||||
if vector_config.vector_db_provider == "pgvector":
|
||||
# Create a maintenance engine, used when creating new postgres databases.
|
||||
# Database named "postgres" should always exist. We need this since the SQLAlchemy
|
||||
# engine cannot directly execute queries without first connecting to a database.
|
||||
maintenance_db_name = "postgres"
|
||||
maintenance_db_url = URL.create(
|
||||
"postgresql+asyncpg",
|
||||
username=vector_config.vector_db_username,
|
||||
password=vector_config.vector_db_password,
|
||||
host=vector_config.vector_db_url,
|
||||
port=int(vector_config.vector_db_port),
|
||||
database=maintenance_db_name,
|
||||
)
|
||||
self.maintenance_engine = create_async_engine(maintenance_db_url)
|
||||
|
||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||
|
||||
async def push_to_s3(self) -> None:
|
||||
|
|
@ -517,9 +538,32 @@ class SQLAlchemyAdapter:
|
|||
if not await file_storage.file_exists(db_name):
|
||||
await file_storage.ensure_directory_exists()
|
||||
|
||||
async with self.engine.begin() as connection:
|
||||
if len(Base.metadata.tables.keys()) > 0:
|
||||
await connection.run_sync(Base.metadata.create_all)
|
||||
from cognee.infrastructure.databases.relational.config import get_relational_config
|
||||
|
||||
relational_config = get_relational_config()
|
||||
|
||||
if self.engine.dialect.name == "sqlite" or (
|
||||
self.engine.dialect.name == "postgresql"
|
||||
and relational_config.db_provider == "postgres"
|
||||
and self.engine.url.database == relational_config.db_name
|
||||
):
|
||||
# In this case we already have a relational db created in sqlite or postgres, we just need to populate it
|
||||
async with self.engine.begin() as connection:
|
||||
if len(Base.metadata.tables.keys()) > 0:
|
||||
await connection.run_sync(Base.metadata.create_all)
|
||||
return
|
||||
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
if self.engine.dialect.name == "postgresql" and backend_access_control_enabled():
|
||||
# Connect to maintenance db in order to create new database
|
||||
# Make sure to execute CREATE DATABASE outside of transaction block, and set AUTOCOMMIT isolation level
|
||||
connection = await self.maintenance_engine.connect()
|
||||
await connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
await connection.execute(text(f'CREATE DATABASE "{self.engine.url.database}";'))
|
||||
|
||||
# Clean up resources
|
||||
await connection.close()
|
||||
|
||||
async def delete_database(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ class VectorConfig(BaseSettings):
|
|||
vector_db_key: str = ""
|
||||
vector_db_provider: str = "lancedb"
|
||||
vector_dataset_database_handler: str = "lancedb"
|
||||
vector_db_username: str = ""
|
||||
vector_db_password: str = ""
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||
|
||||
|
|
@ -65,6 +67,8 @@ class VectorConfig(BaseSettings):
|
|||
"vector_db_key": self.vector_db_key,
|
||||
"vector_db_provider": self.vector_db_provider,
|
||||
"vector_dataset_database_handler": self.vector_dataset_database_handler,
|
||||
"vector_db_username": self.vector_db_username,
|
||||
"vector_db_password": self.vector_db_password,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ def create_vector_engine(
|
|||
vector_db_port: str = "",
|
||||
vector_db_key: str = "",
|
||||
vector_dataset_database_handler: str = "",
|
||||
vector_db_username: str = "",
|
||||
vector_db_password: str = "",
|
||||
):
|
||||
"""
|
||||
Create a vector database engine based on the specified provider.
|
||||
|
|
@ -55,27 +57,43 @@ def create_vector_engine(
|
|||
)
|
||||
|
||||
if vector_db_provider.lower() == "pgvector":
|
||||
from cognee.infrastructure.databases.relational import get_relational_config
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
# Get configuration for postgres database
|
||||
relational_config = get_relational_config()
|
||||
db_username = relational_config.db_username
|
||||
db_password = relational_config.db_password
|
||||
db_host = relational_config.db_host
|
||||
db_port = relational_config.db_port
|
||||
db_name = relational_config.db_name
|
||||
if backend_access_control_enabled():
|
||||
connection_string: str = (
|
||||
f"postgresql+asyncpg://{vector_db_username}:{vector_db_password}"
|
||||
f"@{vector_db_url}:{vector_db_port}/{vector_db_name}"
|
||||
)
|
||||
else:
|
||||
if (
|
||||
vector_db_port
|
||||
and vector_db_username
|
||||
and vector_db_password
|
||||
and vector_db_url
|
||||
and vector_db_name
|
||||
):
|
||||
connection_string: str = (
|
||||
f"postgresql+asyncpg://{vector_db_username}:{vector_db_password}"
|
||||
f"@{vector_db_url}:{vector_db_port}/{vector_db_name}"
|
||||
)
|
||||
else:
|
||||
from cognee.infrastructure.databases.relational import get_relational_config
|
||||
|
||||
if not (db_host and db_port and db_name and db_username and db_password):
|
||||
raise EnvironmentError("Missing requred pgvector credentials!")
|
||||
# Get configuration for postgres database
|
||||
relational_config = get_relational_config()
|
||||
db_username = relational_config.db_username
|
||||
db_password = relational_config.db_password
|
||||
db_host = relational_config.db_host
|
||||
db_port = relational_config.db_port
|
||||
db_name = relational_config.db_name
|
||||
|
||||
connection_string = URL.create(
|
||||
"postgresql+asyncpg",
|
||||
username=db_username,
|
||||
password=db_password,
|
||||
host=db_host,
|
||||
port=int(db_port),
|
||||
database=db_name,
|
||||
)
|
||||
if not (db_host and db_port and db_name and db_username and db_password):
|
||||
raise EnvironmentError("Missing required pgvector credentials!")
|
||||
|
||||
connection_string: str = (
|
||||
f"postgresql+asyncpg://{db_username}:{db_password}"
|
||||
f"@{db_host}:{db_port}/{db_name}"
|
||||
)
|
||||
|
||||
try:
|
||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||
|
|
|
|||
|
|
@ -0,0 +1,97 @@
|
|||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.models import DatasetDatabase
|
||||
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||
|
||||
|
||||
class PGVectorDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with PGVector Dataset databases.

    Every dataset gets a Postgres database of its own, named after the dataset id
    and reached with the host, port and credentials of the global vector config.
    """

    @classmethod
    async def _create_pg_database(cls, vector_config):
        """
        Create the necessary Postgres database, and the PGVector extension on it.

        This is defined here because the creation needs the latest vector config,
        which is not yet saved in the vector config context variable here.

        `vector_config` is a dict of `create_vector_engine` keyword arguments
        (`vector_db_provider`, `vector_db_url`, `vector_db_name`, `vector_db_port`,
        `vector_db_key`, `vector_db_username`, `vector_db_password`,
        `vector_dataset_database_handler`). It is forwarded as-is, so key order is
        left to the caller.
        """
        # Kept as function-scope imports, as in the original implementation.
        # NOTE(review): presumably this avoids an import cycle with the relational package - confirm.
        from cognee.infrastructure.databases.relational.create_relational_engine import (
            create_relational_engine,
        )
        from sqlalchemy import text

        # A relational engine pointed at the dataset database name is used only to
        # issue the database creation for that name.
        pg_relational_engine = create_relational_engine(
            db_path="",
            db_host=vector_config["vector_db_url"],
            db_name=vector_config["vector_db_name"],
            db_port=vector_config["vector_db_port"],
            db_username=vector_config["vector_db_username"],
            db_password=vector_config["vector_db_password"],
            db_provider="postgres",
        )
        await pg_relational_engine.create_database()

        # The extension is enabled per database, so it has to be created inside the
        # new database, through an engine that connects to it.
        vector_engine = create_vector_engine(**vector_config)
        async with vector_engine.engine.begin() as connection:
            await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a dedicated PGVector database for a dataset and describe how to reach it.

        The database is named after `dataset_id`. Host, port, username and password
        are copied from the global vector config. `user` is accepted to satisfy the
        handler interface and is not used here.

        Returns a dict with the keys `vector_database_provider`, `vector_database_url`,
        `vector_database_name`, `vector_database_connection_info` (`port`, `username`,
        `password`) and `vector_dataset_database_handler`.

        Raises ValueError when the configured vector provider is not pgvector, or when
        `dataset_id` is None.
        """
        vector_config = get_vectordb_config()

        if vector_config.vector_db_provider != "pgvector":
            raise ValueError(
                "PGVectorDatasetDatabaseHandler can only be used with PGVector vector database provider."
            )

        # Without an id the database name would be the literal string "None", shared
        # by every id-less call. Fail early with a clear message instead.
        if dataset_id is None:
            raise ValueError(
                "PGVectorDatasetDatabaseHandler requires a dataset_id to name the dataset database."
            )

        vector_db_name = f"{dataset_id}"

        new_vector_config = {
            "vector_database_provider": vector_config.vector_db_provider,
            "vector_database_url": vector_config.vector_db_url,
            "vector_database_name": vector_db_name,
            "vector_database_connection_info": {
                "port": vector_config.vector_db_port,
                "username": vector_config.vector_db_username,
                "password": vector_config.vector_db_password,
            },
            "vector_dataset_database_handler": "pgvector",
        }

        connection_info = new_vector_config["vector_database_connection_info"]

        # Same values, re-keyed to the create_vector_engine parameter names.
        # Key order is kept as it was: the engine factory cache is order-sensitive.
        await cls._create_pg_database(
            {
                "vector_db_provider": new_vector_config["vector_database_provider"],
                "vector_db_url": new_vector_config["vector_database_url"],
                "vector_db_name": new_vector_config["vector_database_name"],
                "vector_db_port": connection_info["port"],
                "vector_db_key": "",
                "vector_db_username": connection_info["username"],
                "vector_db_password": connection_info["password"],
                "vector_dataset_database_handler": "pgvector",
            }
        )

        return new_vector_config

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        """
        Prune the vector data stored for the given dataset database record.

        An engine is built from the provider, url, name, port and credentials stored
        on `dataset_database`, and its `prune()` is awaited. A KeyError is raised if
        the stored connection info lacks `port`, `username` or `password`.
        """
        connection_info = dataset_database.vector_database_connection_info

        vector_engine = create_vector_engine(
            vector_db_provider=dataset_database.vector_database_provider,
            vector_db_url=dataset_database.vector_database_url,
            vector_db_name=dataset_database.vector_database_name,
            vector_db_port=connection_info["port"],
            vector_db_username=connection_info["username"],
            vector_db_password=connection_info["password"],
        )
        # NOTE(review): only prune() is called; whether the per-dataset Postgres
        # database itself gets dropped is not visible from this handler - confirm.
        await vector_engine.prune()
|
||||
|
|
@ -4,6 +4,7 @@ from cognee.infrastructure.databases.relational import (
|
|||
from cognee.infrastructure.databases.vector.pgvector import (
|
||||
create_db_and_tables as create_pgvector_db_and_tables,
|
||||
)
|
||||
from cognee.context_global_variables import backend_access_control_enabled
|
||||
|
||||
|
||||
async def setup():
|
||||
|
|
@ -14,4 +15,5 @@ async def setup():
|
|||
followed by creating a PGVector database and its tables.
|
||||
"""
|
||||
await create_relational_db_and_tables()
|
||||
await create_pgvector_db_and_tables()
|
||||
if not backend_access_control_enabled():
|
||||
await create_pgvector_db_and_tables()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
from cognee.context_global_variables import (
|
||||
graph_db_config as context_graph_db_config,
|
||||
vector_db_config as context_vector_db_config,
|
||||
backend_access_control_enabled,
|
||||
)
|
||||
|
||||
from cognee.infrastructure.databases.relational import (
|
||||
|
|
@ -26,7 +27,8 @@ async def setup_and_check_environment(
|
|||
|
||||
# Create tables for databases
|
||||
await create_relational_db_and_tables()
|
||||
await create_pgvector_db_and_tables()
|
||||
if not backend_access_control_enabled():
|
||||
await create_pgvector_db_and_tables()
|
||||
|
||||
global _first_run_done
|
||||
async with _first_run_lock:
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue