feat: add multi user for PGVector, and enable the usage of PGVector with SQLite
This commit is contained in:
parent
ab990f7c5c
commit
67dfb37709
8 changed files with 202 additions and 25 deletions
|
|
@ -121,13 +121,16 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set vector and graph database configuration based on dataset database information
|
# Set vector and graph database configuration based on dataset database information
|
||||||
# TODO: Add better handling of vector and graph config accross Cognee.
|
# TODO: Add better handling of vector and graph config across Cognee.
|
||||||
# LRU_CACHE takes into account order of inputs, if order of inputs is changed it will be registered as a new DB adapter
|
# LRU_CACHE takes into account order of inputs, if order of inputs is changed it will be registered as a new DB adapter
|
||||||
vector_config = {
|
vector_config = {
|
||||||
"vector_db_provider": dataset_database.vector_database_provider,
|
"vector_db_provider": dataset_database.vector_database_provider,
|
||||||
"vector_db_url": dataset_database.vector_database_url,
|
"vector_db_url": dataset_database.vector_database_url,
|
||||||
"vector_db_key": dataset_database.vector_database_key,
|
"vector_db_key": dataset_database.vector_database_key,
|
||||||
"vector_db_name": dataset_database.vector_database_name,
|
"vector_db_name": dataset_database.vector_database_name,
|
||||||
|
"vector_db_port": dataset_database.vector_database_connection_info.get("port", ""),
|
||||||
|
"vector_db_username": dataset_database.vector_database_connection_info.get("username", ""),
|
||||||
|
"vector_db_password": dataset_database.vector_database_connection_info.get("password", ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
graph_config = {
|
graph_config = {
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,9 @@ from cognee.infrastructure.databases.vector.lancedb.LanceDBDatasetDatabaseHandle
|
||||||
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
|
from cognee.infrastructure.databases.graph.kuzu.KuzuDatasetDatabaseHandler import (
|
||||||
KuzuDatasetDatabaseHandler,
|
KuzuDatasetDatabaseHandler,
|
||||||
)
|
)
|
||||||
|
from cognee.infrastructure.databases.vector.pgvector.PGVectorDatasetDatabaseHandler import (
|
||||||
|
PGVectorDatasetDatabaseHandler,
|
||||||
|
)
|
||||||
|
|
||||||
supported_dataset_database_handlers = {
|
supported_dataset_database_handlers = {
|
||||||
"neo4j_aura_dev": {
|
"neo4j_aura_dev": {
|
||||||
|
|
@ -14,5 +17,9 @@ supported_dataset_database_handlers = {
|
||||||
"handler_provider": "neo4j",
|
"handler_provider": "neo4j",
|
||||||
},
|
},
|
||||||
"lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
|
"lancedb": {"handler_instance": LanceDBDatasetDatabaseHandler, "handler_provider": "lancedb"},
|
||||||
|
"pgvector": {
|
||||||
|
"handler_instance": PGVectorDatasetDatabaseHandler,
|
||||||
|
"handler_provider": "pgvector",
|
||||||
|
},
|
||||||
"kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
|
"kuzu": {"handler_instance": KuzuDatasetDatabaseHandler, "handler_provider": "kuzu"},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from typing import AsyncGenerator, List
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from sqlalchemy.exc import NoResultFound
|
from sqlalchemy.exc import NoResultFound
|
||||||
from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect
|
from sqlalchemy import NullPool, text, select, MetaData, Table, delete, inspect, URL
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
|
|
||||||
from cognee.modules.data.models.Data import Data
|
from cognee.modules.data.models.Data import Data
|
||||||
|
|
@ -87,6 +87,27 @@ class SQLAlchemyAdapter:
|
||||||
connect_args=final_connect_args,
|
connect_args=final_connect_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
|
|
||||||
|
if backend_access_control_enabled():
|
||||||
|
from cognee.infrastructure.databases.vector.config import get_vectordb_config
|
||||||
|
|
||||||
|
vector_config = get_vectordb_config()
|
||||||
|
if vector_config.vector_db_provider == "pgvector":
|
||||||
|
# Create a maintenance engine, used when creating new postgres databases.
|
||||||
|
# Database named "postgres" should always exist. We need this since the SQLAlchemy
|
||||||
|
# engine cannot directly execute queries without first connecting to a database.
|
||||||
|
maintenance_db_name = "postgres"
|
||||||
|
maintenance_db_url = URL.create(
|
||||||
|
"postgresql+asyncpg",
|
||||||
|
username=vector_config.vector_db_username,
|
||||||
|
password=vector_config.vector_db_password,
|
||||||
|
host=vector_config.vector_db_url,
|
||||||
|
port=int(vector_config.vector_db_port),
|
||||||
|
database=maintenance_db_name,
|
||||||
|
)
|
||||||
|
self.maintenance_engine = create_async_engine(maintenance_db_url)
|
||||||
|
|
||||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||||
|
|
||||||
async def push_to_s3(self) -> None:
|
async def push_to_s3(self) -> None:
|
||||||
|
|
@ -517,9 +538,32 @@ class SQLAlchemyAdapter:
|
||||||
if not await file_storage.file_exists(db_name):
|
if not await file_storage.file_exists(db_name):
|
||||||
await file_storage.ensure_directory_exists()
|
await file_storage.ensure_directory_exists()
|
||||||
|
|
||||||
async with self.engine.begin() as connection:
|
from cognee.infrastructure.databases.relational.config import get_relational_config
|
||||||
if len(Base.metadata.tables.keys()) > 0:
|
|
||||||
await connection.run_sync(Base.metadata.create_all)
|
relational_config = get_relational_config()
|
||||||
|
|
||||||
|
if self.engine.dialect.name == "sqlite" or (
|
||||||
|
self.engine.dialect.name == "postgresql"
|
||||||
|
and relational_config.db_provider == "postgres"
|
||||||
|
and self.engine.url.database == relational_config.db_name
|
||||||
|
):
|
||||||
|
# In this case we already have a relational db created in sqlite or postgres, we just need to populate it
|
||||||
|
async with self.engine.begin() as connection:
|
||||||
|
if len(Base.metadata.tables.keys()) > 0:
|
||||||
|
await connection.run_sync(Base.metadata.create_all)
|
||||||
|
return
|
||||||
|
|
||||||
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
|
|
||||||
|
if self.engine.dialect.name == "postgresql" and backend_access_control_enabled():
|
||||||
|
# Connect to maintenance db in order to create new database
|
||||||
|
# Make sure to execute CREATE DATABASE outside of transaction block, and set AUTOCOMMIT isolation level
|
||||||
|
connection = await self.maintenance_engine.connect()
|
||||||
|
await connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||||
|
await connection.execute(text(f'CREATE DATABASE "{self.engine.url.database}";'))
|
||||||
|
|
||||||
|
# Clean up resources
|
||||||
|
await connection.close()
|
||||||
|
|
||||||
async def delete_database(self):
|
async def delete_database(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,8 @@ class VectorConfig(BaseSettings):
|
||||||
vector_db_key: str = ""
|
vector_db_key: str = ""
|
||||||
vector_db_provider: str = "lancedb"
|
vector_db_provider: str = "lancedb"
|
||||||
vector_dataset_database_handler: str = "lancedb"
|
vector_dataset_database_handler: str = "lancedb"
|
||||||
|
vector_db_username: str = ""
|
||||||
|
vector_db_password: str = ""
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
|
|
@ -65,6 +67,8 @@ class VectorConfig(BaseSettings):
|
||||||
"vector_db_key": self.vector_db_key,
|
"vector_db_key": self.vector_db_key,
|
||||||
"vector_db_provider": self.vector_db_provider,
|
"vector_db_provider": self.vector_db_provider,
|
||||||
"vector_dataset_database_handler": self.vector_dataset_database_handler,
|
"vector_dataset_database_handler": self.vector_dataset_database_handler,
|
||||||
|
"vector_db_username": self.vector_db_username,
|
||||||
|
"vector_db_password": self.vector_db_password,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,8 @@ def create_vector_engine(
|
||||||
vector_db_port: str = "",
|
vector_db_port: str = "",
|
||||||
vector_db_key: str = "",
|
vector_db_key: str = "",
|
||||||
vector_dataset_database_handler: str = "",
|
vector_dataset_database_handler: str = "",
|
||||||
|
vector_db_username: str = "",
|
||||||
|
vector_db_password: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a vector database engine based on the specified provider.
|
Create a vector database engine based on the specified provider.
|
||||||
|
|
@ -55,27 +57,43 @@ def create_vector_engine(
|
||||||
)
|
)
|
||||||
|
|
||||||
if vector_db_provider.lower() == "pgvector":
|
if vector_db_provider.lower() == "pgvector":
|
||||||
from cognee.infrastructure.databases.relational import get_relational_config
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
|
|
||||||
# Get configuration for postgres database
|
if backend_access_control_enabled():
|
||||||
relational_config = get_relational_config()
|
connection_string: str = (
|
||||||
db_username = relational_config.db_username
|
f"postgresql+asyncpg://{vector_db_username}:{vector_db_password}"
|
||||||
db_password = relational_config.db_password
|
f"@{vector_db_url}:{vector_db_port}/{vector_db_name}"
|
||||||
db_host = relational_config.db_host
|
)
|
||||||
db_port = relational_config.db_port
|
else:
|
||||||
db_name = relational_config.db_name
|
if (
|
||||||
|
vector_db_port
|
||||||
|
and vector_db_username
|
||||||
|
and vector_db_password
|
||||||
|
and vector_db_url
|
||||||
|
and vector_db_name
|
||||||
|
):
|
||||||
|
connection_string: str = (
|
||||||
|
f"postgresql+asyncpg://{vector_db_username}:{vector_db_password}"
|
||||||
|
f"@{vector_db_url}:{vector_db_port}/{vector_db_name}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from cognee.infrastructure.databases.relational import get_relational_config
|
||||||
|
|
||||||
if not (db_host and db_port and db_name and db_username and db_password):
|
# Get configuration for postgres database
|
||||||
raise EnvironmentError("Missing requred pgvector credentials!")
|
relational_config = get_relational_config()
|
||||||
|
db_username = relational_config.db_username
|
||||||
|
db_password = relational_config.db_password
|
||||||
|
db_host = relational_config.db_host
|
||||||
|
db_port = relational_config.db_port
|
||||||
|
db_name = relational_config.db_name
|
||||||
|
|
||||||
connection_string = URL.create(
|
if not (db_host and db_port and db_name and db_username and db_password):
|
||||||
"postgresql+asyncpg",
|
raise EnvironmentError("Missing required pgvector credentials!")
|
||||||
username=db_username,
|
|
||||||
password=db_password,
|
connection_string: str = (
|
||||||
host=db_host,
|
f"postgresql+asyncpg://{db_username}:{db_password}"
|
||||||
port=int(db_port),
|
f"@{db_host}:{db_port}/{db_name}"
|
||||||
database=db_name,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
from uuid import UUID
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cognee.infrastructure.databases.vector.create_vector_engine import create_vector_engine
|
||||||
|
from cognee.modules.users.models import User
|
||||||
|
from cognee.modules.users.models import DatasetDatabase
|
||||||
|
from cognee.infrastructure.databases.vector import get_vectordb_config
|
||||||
|
from cognee.infrastructure.databases.dataset_database_handler import DatasetDatabaseHandlerInterface
|
||||||
|
|
||||||
|
|
||||||
|
class PGVectorDatasetDatabaseHandler(DatasetDatabaseHandlerInterface):
    """
    Handler for interacting with PGVector Dataset databases.

    Each dataset gets its own Postgres database, named after the dataset id,
    with the ``vector`` extension created inside it.
    """

    @classmethod
    async def _create_pg_database(cls, vector_config):
        """
        Create the necessary Postgres database, and the PGVector extension on it.

        This is defined here because the creation needs the latest vector config,
        which is not yet saved in the vector config context variable here.

        Parameters:
        -----------

            - vector_config: Mapping of ``create_vector_engine`` keyword arguments. The keys
              ``vector_db_url``, ``vector_db_name``, ``vector_db_port``, ``vector_db_username``
              and ``vector_db_password`` are also used to build the relational engine that
              creates the database.
        """
        # Imported lazily, as in the original implementation.
        # NOTE(review): presumably to avoid an import cycle - confirm.
        from cognee.infrastructure.databases.relational.create_relational_engine import (
            create_relational_engine,
        )
        from sqlalchemy import text

        # A relational engine pointed at the new database name is used only to
        # issue the database creation.
        database_creator = create_relational_engine(
            db_path="",
            db_host=vector_config["vector_db_url"],
            db_name=vector_config["vector_db_name"],
            db_port=vector_config["vector_db_port"],
            db_username=vector_config["vector_db_username"],
            db_password=vector_config["vector_db_password"],
            db_provider="postgres",
        )
        await database_creator.create_database()

        # The extension has to be created inside the freshly created database,
        # so connect through a vector engine built from the same configuration.
        dataset_vector_engine = create_vector_engine(**vector_config)
        async with dataset_vector_engine.engine.begin() as connection:
            await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))

    @classmethod
    async def create_dataset(cls, dataset_id: Optional[UUID], user: Optional[User]) -> dict:
        """
        Create a dedicated PGVector database for a dataset and describe how to reach it.

        Parameters:
        -----------

            - dataset_id (Optional[UUID]): Identifier of the dataset; its string form is used
              as the database name.
            - user (Optional[User]): Not used by this handler.

        Returns:
        --------

            - dict: Connection description with the keys ``vector_database_provider``,
              ``vector_database_url``, ``vector_database_name``,
              ``vector_database_connection_info`` (port, username, password) and
              ``vector_dataset_database_handler``.

        Raises:
        -------

            - ValueError: If the configured vector database provider is not "pgvector".
        """
        vector_config = get_vectordb_config()

        if vector_config.vector_db_provider != "pgvector":
            raise ValueError(
                "PGVectorDatasetDatabaseHandler can only be used with PGVector vector database provider."
            )

        # NOTE(review): a dataset_id of None yields a database literally named "None" - confirm
        # whether callers can ever pass None here.
        vector_db_name = f"{dataset_id}"

        # NOTE(review): the password is returned as-is inside the connection info; how the
        # caller stores this dict is not visible here - confirm it is protected at rest.
        connection_info = {
            "port": vector_config.vector_db_port,
            "username": vector_config.vector_db_username,
            "password": vector_config.vector_db_password,
        }

        new_vector_config = {
            "vector_database_provider": vector_config.vector_db_provider,
            "vector_database_url": vector_config.vector_db_url,
            "vector_database_name": vector_db_name,
            "vector_database_connection_info": connection_info,
            "vector_dataset_database_handler": "pgvector",
        }

        # Same values as above, expressed with the keyword names create_vector_engine expects.
        # Key order is kept stable on purpose in case the engine factory caches on its inputs.
        engine_kwargs = {
            "vector_db_provider": vector_config.vector_db_provider,
            "vector_db_url": vector_config.vector_db_url,
            "vector_db_name": vector_db_name,
            "vector_db_port": connection_info["port"],
            "vector_db_key": "",
            "vector_db_username": connection_info["username"],
            "vector_db_password": connection_info["password"],
            "vector_dataset_database_handler": "pgvector",
        }
        await cls._create_pg_database(engine_kwargs)

        return new_vector_config

    @classmethod
    async def delete_dataset(cls, dataset_database: DatasetDatabase):
        """
        Prune the vector database that belongs to the given dataset database record.

        Parameters:
        -----------

            - dataset_database (DatasetDatabase): Record holding the provider, url, name and
              connection info (port, username, password) of the dataset's vector database.
        """
        connection_info = dataset_database.vector_database_connection_info

        # Direct indexing is intentional: a missing key must fail loudly instead of
        # silently building an engine from empty connection values.
        vector_engine = create_vector_engine(
            vector_db_provider=dataset_database.vector_database_provider,
            vector_db_url=dataset_database.vector_database_url,
            vector_db_name=dataset_database.vector_database_name,
            vector_db_port=connection_info["port"],
            vector_db_username=connection_info["username"],
            vector_db_password=connection_info["password"],
        )
        # NOTE(review): prune() is defined on the adapter and is not visible here - confirm
        # whether it also drops the Postgres database or only clears its contents.
        await vector_engine.prune()
|
||||||
|
|
@ -4,6 +4,7 @@ from cognee.infrastructure.databases.relational import (
|
||||||
from cognee.infrastructure.databases.vector.pgvector import (
|
from cognee.infrastructure.databases.vector.pgvector import (
|
||||||
create_db_and_tables as create_pgvector_db_and_tables,
|
create_db_and_tables as create_pgvector_db_and_tables,
|
||||||
)
|
)
|
||||||
|
from cognee.context_global_variables import backend_access_control_enabled
|
||||||
|
|
||||||
|
|
||||||
async def setup():
|
async def setup():
|
||||||
|
|
@ -14,4 +15,5 @@ async def setup():
|
||||||
followed by creating a PGVector database and its tables.
|
followed by creating a PGVector database and its tables.
|
||||||
"""
|
"""
|
||||||
await create_relational_db_and_tables()
|
await create_relational_db_and_tables()
|
||||||
await create_pgvector_db_and_tables()
|
if not backend_access_control_enabled():
|
||||||
|
await create_pgvector_db_and_tables()
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import asyncio
|
||||||
from cognee.context_global_variables import (
|
from cognee.context_global_variables import (
|
||||||
graph_db_config as context_graph_db_config,
|
graph_db_config as context_graph_db_config,
|
||||||
vector_db_config as context_vector_db_config,
|
vector_db_config as context_vector_db_config,
|
||||||
|
backend_access_control_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
from cognee.infrastructure.databases.relational import (
|
from cognee.infrastructure.databases.relational import (
|
||||||
|
|
@ -26,7 +27,8 @@ async def setup_and_check_environment(
|
||||||
|
|
||||||
# Create tables for databases
|
# Create tables for databases
|
||||||
await create_relational_db_and_tables()
|
await create_relational_db_and_tables()
|
||||||
await create_pgvector_db_and_tables()
|
if not backend_access_control_enabled():
|
||||||
|
await create_pgvector_db_and_tables()
|
||||||
|
|
||||||
global _first_run_done
|
global _first_run_done
|
||||||
async with _first_run_lock:
|
async with _first_run_lock:
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue