pgvector expansion for dataset isolation

This commit is contained in:
Vaggelis Stamkopoulos 2025-12-15 15:39:30 +02:00
parent 75fea8dcc8
commit 11a8dcce4f
6 changed files with 90 additions and 23 deletions

View file

@ -61,7 +61,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
databases_directory_path, dataset_database.vector_database_name
),
"vector_db_key": "",
"vector_db_provider": "lancedb",
"vector_db_provider": "pgvector", # CHANGED from "lancedb"
"vector_db_schema": dataset_database.vector_database_name, # NEW: PostgreSQL schema for pgvector isolation
}
graph_config = {

View file

@ -31,9 +31,12 @@ async def get_or_create_dataset_database(
db_engine = get_relational_engine()
dataset_id = await get_unique_dataset_id(dataset, user)
# CHANGED: Generate schema name instead of file name
# Postgres identifiers: lowercase, no dashes, max 63 chars
dataset_id_hex = str(dataset_id).replace("-", "")
vector_db_name = f"vec_{dataset_id_hex}" # e.g., "vec_9f305871db4f4dc6..."
vector_db_name = f"{dataset_id}.lance.db"
graph_db_name = f"{dataset_id}.pkl"
graph_db_name = f"graph_{dataset_id_hex}.pkl"
async with db_engine.get_async_session() as session:
# Create dataset if it doesn't exist

View file

@ -26,6 +26,7 @@ class VectorConfig(BaseSettings):
vector_db_port: int = 1234
vector_db_key: str = ""
vector_db_provider: str = "lancedb"
vector_db_schema: str = "" # NEW: PostgreSQL schema for pgvector isolation
model_config = SettingsConfigDict(env_file=".env", extra="allow")
@ -60,6 +61,7 @@ class VectorConfig(BaseSettings):
"vector_db_port": self.vector_db_port,
"vector_db_key": self.vector_db_key,
"vector_db_provider": self.vector_db_provider,
"vector_db_schema": self.vector_db_schema, # NEW: PostgreSQL schema for pgvector isolation
}

View file

@ -10,6 +10,7 @@ def create_vector_engine(
vector_db_url: str,
vector_db_port: str = "",
vector_db_key: str = "",
vector_db_schema: str = "",
):
"""
Create a vector database engine based on the specified provider.
@ -30,7 +31,8 @@ def create_vector_engine(
- vector_db_key (str): The API key or access token for the vector database instance.
- vector_db_provider (str): The name of the vector database provider to use (e.g.,
'pgvector').
- vector_db_schema (str): The schema for the vector database instance. Required for
pgvector.
Returns:
--------
@ -76,6 +78,7 @@ def create_vector_engine(
connection_string,
vector_db_key,
embedding_engine,
schema_name=vector_db_schema,
)
elif vector_db_provider.lower() == "chromadb":

View file

@ -3,7 +3,7 @@ from typing import List, Optional, get_type_hints
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func, text
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@ -50,11 +50,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
connection_string: str,
api_key: Optional[str],
embedding_engine: EmbeddingEngine,
schema_name: str = "",
):
self.api_key = api_key
self.embedding_engine = embedding_engine
self.db_uri: str = connection_string
self.VECTOR_DB_LOCK = asyncio.Lock()
# Schema for project isolation; defaults to "public" if not specified
self.schema_name = schema_name if schema_name else "public"
self._schema_created = False # Track if we've already created the schema
relational_db = get_relational_engine()
@ -73,6 +77,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self.Vector = Vector
async def _ensure_schema_exists(self):
    """Create the PostgreSQL schema if it doesn't exist (for project isolation).

    No-op when the adapter targets the default "public" schema, or when the
    schema was already created during this adapter's lifetime (tracked by
    the _schema_created flag to avoid a round-trip per call).
    """
    if self._schema_created or self.schema_name == "public":
        return
    async with self.engine.begin() as connection:
        # Quote the identifier AND escape embedded double quotes so an
        # unusual schema name cannot break out of the quoted identifier.
        safe_name = self.schema_name.replace('"', '""')
        await connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{safe_name}"'))
    # Mark only after the transaction committed successfully.
    self._schema_created = True
    # Lazy %-style args avoid string formatting when INFO is disabled.
    logger.info("Ensured schema '%s' exists for project isolation", self.schema_name)
async def embed_data(self, data: list[str]) -> list[list[float]]:
"""
Embed a list of texts into vectors using the specified embedding engine.
@ -91,7 +108,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def has_collection(self, collection_name: str) -> bool:
    """
    Check if a specified collection exists in the database schema.

    Parameters:
    -----------

        - collection_name (str): The name of the collection to check for existence.

    Returns:
    --------

        - bool: Returns True if the collection exists in the schema, False otherwise.
    """
    async with self.engine.begin() as connection:
        # Reflect only this adapter's schema so collections belonging to
        # other datasets (other schemas) are not visible here.
        metadata = MetaData()
        await connection.run_sync(
            lambda sync_conn: metadata.reflect(sync_conn, schema=self.schema_name)
        )
        # When a schema is passed explicitly to MetaData.reflect(), the
        # reflected tables are keyed as "schema.table_name".
        qualified_name = f"{self.schema_name}.{collection_name}"
        return qualified_name in metadata.tables
@retry(
retry=retry_if_exception_type(
@ -124,6 +142,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def create_collection(self, collection_name: str, payload_schema=None):
data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size()
schema_name = self.schema_name # Capture for use in class definition
# Ensure the schema exists before creating tables
await self._ensure_schema_exists()
if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK:
@ -145,7 +167,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
"""
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
__table_args__ = {"extend_existing": True, "schema": schema_name}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
@ -161,6 +183,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
logger.debug(f"Created collection '{collection_name}' in schema '{schema_name}'")
@retry(
retry=retry_if_exception_type(DeadlockDetectedError),
@ -181,6 +204,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
)
vector_size = self.embedding_engine.get_vector_size()
schema_name = self.schema_name # Capture for use in class definition
class PGVectorDataPoint(Base):
"""
@ -194,7 +218,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
"""
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
__table_args__ = {"extend_existing": True, "schema": schema_name}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
@ -265,18 +289,23 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def get_table(self, collection_name: str) -> Table:
    """
    Dynamically loads a table using the given collection name
    from the configured schema with an async engine.

    Raises:
    -------

        - CollectionNotFoundError: If no table with that name exists in the schema.
    """
    async with self.engine.begin() as connection:
        # Reflect only this adapter's schema into a fresh MetaData instance.
        metadata = MetaData()
        await connection.run_sync(
            lambda sync_conn: metadata.reflect(sync_conn, schema=self.schema_name)
        )
        # Reflected tables are keyed "schema.table_name" when a schema is given.
        qualified_name = f"{self.schema_name}.{collection_name}"
        try:
            return metadata.tables[qualified_name]
        except KeyError:
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' not found in schema '{self.schema_name}'!",
            ) from None
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
@ -394,6 +423,28 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await session.commit()
return results
async def drop_schema(self):
    """
    Drop the entire PostgreSQL schema and all its tables.

    This is useful for cleanup/reset operations.
    Only drops non-public schemas to prevent accidental data loss.
    """
    if self.schema_name == "public":
        logger.warning("Refusing to drop public schema - use delete_database() instead")
        return
    async with self.engine.begin() as connection:
        # Quote the identifier AND escape embedded double quotes so the
        # schema name cannot break out of the quoted identifier.
        safe_name = self.schema_name.replace('"', '""')
        await connection.execute(text(f'DROP SCHEMA IF EXISTS "{safe_name}" CASCADE'))
    # Reset the flag so the schema can be lazily recreated on next use.
    self._schema_created = False
    # Lazy %-style args avoid string formatting when INFO is disabled.
    logger.info("Dropped schema '%s' and all its contents", self.schema_name)
async def prune(self):
    """Clean up the database/schema backing this adapter."""
    if self.schema_name != "public":
        # Project-specific schema: dropping the whole schema removes every
        # collection in it with a single CASCADE statement.
        await self.drop_schema()
    else:
        # Public schema: fall back to deleting the database
        # (the pre-schema-isolation behavior).
        await self.delete_database()

View file

@ -10,3 +10,10 @@ async def create_db_and_tables():
if vector_config["vector_db_provider"] == "pgvector":
async with vector_engine.engine.begin() as connection:
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
# NEW: Create schema if specified in config
schema_name = vector_config.get("vector_db_schema", "")
if schema_name and schema_name != "public":
await connection.execute(
text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')
)