pgvector expansion for dataset isolation

Vaggelis Stamkopoulos 2025-12-15 15:39:30 +02:00
parent 75fea8dcc8
commit 11a8dcce4f
6 changed files with 90 additions and 23 deletions

View file

@@ -61,7 +61,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
             databases_directory_path, dataset_database.vector_database_name
         ),
         "vector_db_key": "",
-        "vector_db_provider": "lancedb",
+        "vector_db_provider": "pgvector",  # CHANGED from "lancedb"
+        "vector_db_schema": dataset_database.vector_database_name,  # NEW: PostgreSQL schema for pgvector isolation
     }
     graph_config = {
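
Net effect of this hunk, as a sketch of the per-dataset config it now produces (values are illustrative, other keys unchanged):

# Sketch of the per-dataset vector config after this change.
vector_config = {
    # ... unchanged keys (URL, port, key) elided ...
    "vector_db_key": "",
    "vector_db_provider": "pgvector",
    "vector_db_schema": "vec_9f305871db4f4dc6a1b2c3d4e5f60718",  # dataset_database.vector_database_name
}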

View file

@@ -31,9 +31,12 @@ async def get_or_create_dataset_database(
     db_engine = get_relational_engine()
     dataset_id = await get_unique_dataset_id(dataset, user)

-    vector_db_name = f"{dataset_id}.lance.db"
-    graph_db_name = f"{dataset_id}.pkl"
+    # CHANGED: Generate schema name instead of file name
+    # Postgres identifiers: lowercase, no dashes, max 63 chars
+    dataset_id_hex = str(dataset_id).replace("-", "")
+    vector_db_name = f"vec_{dataset_id_hex}"  # e.g., "vec_9f305871db4f4dc6..."
+    graph_db_name = f"graph_{dataset_id_hex}.pkl"

     async with db_engine.get_async_session() as session:
         # Create dataset if it doesn't exist
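
A standalone sketch (not part of the commit) checking that the generated names satisfy the identifier constraints noted above:

import uuid

# A UUID rendered without dashes is 32 hex characters, so the generated name
# stays well under PostgreSQL's 63-character identifier limit and uses only
# lowercase hex and "_".
dataset_id = uuid.uuid4()
dataset_id_hex = str(dataset_id).replace("-", "")

vector_db_name = f"vec_{dataset_id_hex}"
graph_db_name = f"graph_{dataset_id_hex}.pkl"

assert len(dataset_id_hex) == 32
assert len(vector_db_name) == 36        # "vec_" + 32 hex chars < 63
assert vector_db_name == vector_db_name.lower() and "-" not in vector_db_name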

View file

@@ -26,6 +26,7 @@ class VectorConfig(BaseSettings):
     vector_db_port: int = 1234
     vector_db_key: str = ""
     vector_db_provider: str = "lancedb"
+    vector_db_schema: str = ""  # NEW: PostgreSQL schema for pgvector isolation

     model_config = SettingsConfigDict(env_file=".env", extra="allow")

@@ -60,6 +61,7 @@ class VectorConfig(BaseSettings):
             "vector_db_port": self.vector_db_port,
             "vector_db_key": self.vector_db_key,
             "vector_db_provider": self.vector_db_provider,
+            "vector_db_schema": self.vector_db_schema,  # NEW: PostgreSQL schema for pgvector isolation
         }
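
Since VectorConfig reads from .env via pydantic-settings, the new field can also be given a default through the environment; an illustrative excerpt (variable names follow the field names, values are placeholders; the per-dataset context from the first file still supplies its own generated schema at runtime):

# .env (illustrative)
VECTOR_DB_PROVIDER=pgvector
VECTOR_DB_SCHEMA=my_project_schema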

View file

@@ -10,6 +10,7 @@ def create_vector_engine(
     vector_db_url: str,
     vector_db_port: str = "",
     vector_db_key: str = "",
+    vector_db_schema: str = "",
 ):
     """
     Create a vector database engine based on the specified provider.

@@ -30,7 +31,8 @@
     - vector_db_key (str): The API key or access token for the vector database instance.
     - vector_db_provider (str): The name of the vector database provider to use (e.g.,
       'pgvector').
+    - vector_db_schema (str): The schema for the vector database instance. Required for
+      pgvector.

     Returns:
     --------

@@ -76,6 +78,7 @@
             connection_string,
             vector_db_key,
             embedding_engine,
+            schema_name=vector_db_schema,
         )
     elif vector_db_provider.lower() == "chromadb":

View file

@@ -3,7 +3,7 @@ from typing import List, Optional, get_type_hints
 from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
+from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func, text
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from sqlalchemy.exc import ProgrammingError
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@@ -50,11 +50,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         connection_string: str,
         api_key: Optional[str],
         embedding_engine: EmbeddingEngine,
+        schema_name: str = "",
     ):
         self.api_key = api_key
         self.embedding_engine = embedding_engine
         self.db_uri: str = connection_string
         self.VECTOR_DB_LOCK = asyncio.Lock()
+        # Schema for project isolation; defaults to "public" if not specified
+        self.schema_name = schema_name if schema_name else "public"
+        self._schema_created = False  # Track if we've already created the schema

         relational_db = get_relational_engine()
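
A usage sketch of the widened constructor, assuming an already-configured EmbeddingEngine and an illustrative DSN; only the schema_name keyword is new:

adapter = PGVectorAdapter(
    connection_string="postgresql+asyncpg://cognee:cognee@localhost:5432/cognee_db",  # illustrative DSN
    api_key=None,
    embedding_engine=embedding_engine,   # any configured EmbeddingEngine instance
    schema_name="vec_9f305871db4f4dc6a1b2c3d4e5f60718",  # "" falls back to "public"
)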
@@ -73,6 +77,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         self.Vector = Vector

+    async def _ensure_schema_exists(self):
+        """Create the PostgreSQL schema if it doesn't exist (for project isolation)."""
+        if self._schema_created or self.schema_name == "public":
+            return
+
+        async with self.engine.begin() as connection:
+            # Use quoted identifier to handle any special characters
+            await connection.execute(
+                text(f'CREATE SCHEMA IF NOT EXISTS "{self.schema_name}"')
+            )
+
+        self._schema_created = True
+        logger.info(f"Ensured schema '{self.schema_name}' exists for project isolation")
+
     async def embed_data(self, data: list[str]) -> list[list[float]]:
         """
         Embed a list of texts into vectors using the specified embedding engine.
@@ -91,7 +108,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
     async def has_collection(self, collection_name: str) -> bool:
         """
-        Check if a specified collection exists in the database.
+        Check if a specified collection exists in the database schema.

         Parameters:
         -----------

@@ -101,18 +118,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         Returns:
         --------

-        - bool: Returns True if the collection exists, False otherwise.
+        - bool: Returns True if the collection exists in the schema, False otherwise.
         """
         async with self.engine.begin() as connection:
             # Create a MetaData instance to load table information
             metadata = MetaData()

-            # Load table information from schema into MetaData
-            await connection.run_sync(metadata.reflect)
+            # Load table information from our specific schema into MetaData
+            await connection.run_sync(
+                lambda sync_conn: metadata.reflect(sync_conn, schema=self.schema_name)
+            )

-            if collection_name in metadata.tables:
-                return True
-            else:
-                return False
+            # Tables are keyed as "schema.table_name" when schema is specified
+            full_table_name = f"{self.schema_name}.{collection_name}"
+            return full_table_name in metadata.tables

     @retry(
         retry=retry_if_exception_type(
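
The "schema.table" keying that has_collection() now relies on is standard SQLAlchemy behaviour; a minimal standalone sketch (no database connection needed):

from sqlalchemy import Column, Integer, MetaData, Table

metadata = MetaData()
schema = "vec_9f305871db4f4dc6a1b2c3d4e5f60718"  # illustrative schema name

# Tables declared (or reflected) with an explicit schema are registered under
# a "schema.table" key, which is exactly what has_collection() now checks.
Table("documents", metadata, Column("id", Integer, primary_key=True), schema=schema)

assert f"{schema}.documents" in metadata.tables
assert "documents" not in metadata.tables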
@@ -124,6 +142,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
     async def create_collection(self, collection_name: str, payload_schema=None):
         data_point_types = get_type_hints(DataPoint)
         vector_size = self.embedding_engine.get_vector_size()
+        schema_name = self.schema_name  # Capture for use in class definition
+
+        # Ensure the schema exists before creating tables
+        await self._ensure_schema_exists()

         if not await self.has_collection(collection_name):
             async with self.VECTOR_DB_LOCK:
@@ -145,7 +167,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                     """

                     __tablename__ = collection_name
-                    __table_args__ = {"extend_existing": True}
+                    __table_args__ = {"extend_existing": True, "schema": schema_name}
                     # PGVector requires one column to be the primary key
                     id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
                     payload = Column(JSON)

@@ -161,6 +183,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                     await connection.run_sync(
                         Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
                     )
+                    logger.debug(f"Created collection '{collection_name}' in schema '{schema_name}'")

     @retry(
         retry=retry_if_exception_type(DeadlockDetectedError),
@@ -181,6 +204,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         )

         vector_size = self.embedding_engine.get_vector_size()
+        schema_name = self.schema_name  # Capture for use in class definition

         class PGVectorDataPoint(Base):
             """

@@ -194,7 +218,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             """

             __tablename__ = collection_name
-            __table_args__ = {"extend_existing": True}
+            __table_args__ = {"extend_existing": True, "schema": schema_name}
             # PGVector requires one column to be the primary key
             id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
             payload = Column(JSON)
@@ -265,18 +289,23 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
     async def get_table(self, collection_name: str) -> Table:
         """
         Dynamically loads a table using the given collection name
-        with an async engine.
+        from the configured schema with an async engine.
         """
         async with self.engine.begin() as connection:
             # Create a MetaData instance to load table information
             metadata = MetaData()

-            # Load table information from schema into MetaData
-            await connection.run_sync(metadata.reflect)
-            if collection_name in metadata.tables:
-                return metadata.tables[collection_name]
+            # Load table information from our specific schema into MetaData
+            await connection.run_sync(
+                lambda sync_conn: metadata.reflect(sync_conn, schema=self.schema_name)
+            )

+            # Tables are keyed as "schema.table_name" when schema is specified
+            full_table_name = f"{self.schema_name}.{collection_name}"
+            if full_table_name in metadata.tables:
+                return metadata.tables[full_table_name]
             else:
                 raise CollectionNotFoundError(
-                    f"Collection '{collection_name}' not found!",
+                    f"Collection '{collection_name}' not found in schema '{self.schema_name}'!",
                 )

     async def retrieve(self, collection_name: str, data_point_ids: List[str]):
@@ -394,6 +423,28 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             await session.commit()
             return results

+    async def drop_schema(self):
+        """
+        Drop the entire PostgreSQL schema and all its tables.
+
+        This is useful for cleanup/reset operations.
+        Only drops non-public schemas to prevent accidental data loss.
+        """
+        if self.schema_name == "public":
+            logger.warning("Refusing to drop public schema - use delete_database() instead")
+            return
+
+        async with self.engine.begin() as connection:
+            await connection.execute(
+                text(f'DROP SCHEMA IF EXISTS "{self.schema_name}" CASCADE')
+            )
+
+        self._schema_created = False
+        logger.info(f"Dropped schema '{self.schema_name}' and all its contents")
+
     async def prune(self):
-        # Clean up the database if it was set up as temporary
-        await self.delete_database()
+        """Clean up the database/schema."""
+        if self.schema_name != "public":
+            # For project-specific schemas, drop the entire schema
+            await self.drop_schema()
+        else:
+            # For public schema, just delete the database (existing behavior)
+            await self.delete_database()
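
A sketch of the resulting cleanup behaviour (the wrapper function is hypothetical):

async def reset_dataset(adapter: "PGVectorAdapter") -> None:
    # For a dataset-scoped adapter (schema_name != "public") this now runs
    #   DROP SCHEMA IF EXISTS "<schema_name>" CASCADE
    # removing every collection table in that dataset's schema while leaving
    # other datasets' schemas and the public schema untouched.
    # For the default adapter (schema_name == "public") it still calls
    # delete_database(), i.e. the pre-existing behaviour.
    await adapter.prune()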

View file

@@ -10,3 +10,10 @@ async def create_db_and_tables():
     if vector_config["vector_db_provider"] == "pgvector":
         async with vector_engine.engine.begin() as connection:
             await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+
+            # NEW: Create schema if specified in config
+            schema_name = vector_config.get("vector_db_schema", "")
+            if schema_name and schema_name != "public":
+                await connection.execute(
+                    text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')
+                )
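
To spot-check isolation after ingesting a few datasets, one can list the generated schemas and their tables; a sketch using SQLAlchemy's async API (DSN and the "vec_" prefix convention are taken from this commit, the function itself is illustrative):

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

async def list_dataset_schemas(dsn: str) -> None:
    # dsn e.g. "postgresql+asyncpg://cognee:cognee@localhost:5432/cognee_db"
    engine = create_async_engine(dsn)
    async with engine.connect() as conn:
        rows = await conn.execute(text(
            "SELECT table_schema, table_name FROM information_schema.tables "
            "WHERE table_schema LIKE 'vec\\_%' ORDER BY table_schema, table_name"
        ))
        for schema, table in rows:
            print(f"{schema}.{table}")
    await engine.dispose()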