pgvector expansion for dataset isolation
This commit is contained in:
parent
75fea8dcc8
commit
11a8dcce4f
6 changed files with 90 additions and 23 deletions
|
|
@ -61,7 +61,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
|
||||||
databases_directory_path, dataset_database.vector_database_name
|
databases_directory_path, dataset_database.vector_database_name
|
||||||
),
|
),
|
||||||
"vector_db_key": "",
|
"vector_db_key": "",
|
||||||
"vector_db_provider": "lancedb",
|
"vector_db_provider": "pgvector", # CHANGED from "lancedb"
|
||||||
|
"vector_db_schema": dataset_database.vector_database_name, # NEW: PostgreSQL schema for pgvector isolation
|
||||||
}
|
}
|
||||||
|
|
||||||
graph_config = {
|
graph_config = {
|
||||||
|
|
|
||||||
|
|
@ -31,9 +31,12 @@ async def get_or_create_dataset_database(
|
||||||
db_engine = get_relational_engine()
|
db_engine = get_relational_engine()
|
||||||
|
|
||||||
dataset_id = await get_unique_dataset_id(dataset, user)
|
dataset_id = await get_unique_dataset_id(dataset, user)
|
||||||
|
# CHANGED: Generate schema name instead of file name
|
||||||
|
# Postgres identifiers: lowercase, no dashes, max 63 chars
|
||||||
|
dataset_id_hex = str(dataset_id).replace("-", "")
|
||||||
|
vector_db_name = f"vec_{dataset_id_hex}" # e.g., "vec_9f305871db4f4dc6..."
|
||||||
|
|
||||||
vector_db_name = f"{dataset_id}.lance.db"
|
graph_db_name = f"graph_{dataset_id_hex}.pkl"
|
||||||
graph_db_name = f"{dataset_id}.pkl"
|
|
||||||
|
|
||||||
async with db_engine.get_async_session() as session:
|
async with db_engine.get_async_session() as session:
|
||||||
# Create dataset if it doesn't exist
|
# Create dataset if it doesn't exist
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ class VectorConfig(BaseSettings):
|
||||||
vector_db_port: int = 1234
|
vector_db_port: int = 1234
|
||||||
vector_db_key: str = ""
|
vector_db_key: str = ""
|
||||||
vector_db_provider: str = "lancedb"
|
vector_db_provider: str = "lancedb"
|
||||||
|
vector_db_schema: str = "" # NEW: PostgreSQL schema for pgvector isolation
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
model_config = SettingsConfigDict(env_file=".env", extra="allow")
|
||||||
|
|
||||||
|
|
@ -60,6 +61,7 @@ class VectorConfig(BaseSettings):
|
||||||
"vector_db_port": self.vector_db_port,
|
"vector_db_port": self.vector_db_port,
|
||||||
"vector_db_key": self.vector_db_key,
|
"vector_db_key": self.vector_db_key,
|
||||||
"vector_db_provider": self.vector_db_provider,
|
"vector_db_provider": self.vector_db_provider,
|
||||||
|
"vector_db_schema": self.vector_db_schema, # NEW: PostgreSQL schema for pgvector isolation
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ def create_vector_engine(
|
||||||
vector_db_url: str,
|
vector_db_url: str,
|
||||||
vector_db_port: str = "",
|
vector_db_port: str = "",
|
||||||
vector_db_key: str = "",
|
vector_db_key: str = "",
|
||||||
|
vector_db_schema: str = "",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a vector database engine based on the specified provider.
|
Create a vector database engine based on the specified provider.
|
||||||
|
|
@ -30,7 +31,8 @@ def create_vector_engine(
|
||||||
- vector_db_key (str): The API key or access token for the vector database instance.
|
- vector_db_key (str): The API key or access token for the vector database instance.
|
||||||
- vector_db_provider (str): The name of the vector database provider to use (e.g.,
|
- vector_db_provider (str): The name of the vector database provider to use (e.g.,
|
||||||
'pgvector').
|
'pgvector').
|
||||||
|
- vector_db_schema (str): The schema for the vector database instance. Required for
|
||||||
|
pgvector.
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
||||||
|
|
@ -76,6 +78,7 @@ def create_vector_engine(
|
||||||
connection_string,
|
connection_string,
|
||||||
vector_db_key,
|
vector_db_key,
|
||||||
embedding_engine,
|
embedding_engine,
|
||||||
|
schema_name=vector_db_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif vector_db_provider.lower() == "chromadb":
|
elif vector_db_provider.lower() == "chromadb":
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import List, Optional, get_type_hints
|
||||||
from sqlalchemy.inspection import inspect
|
from sqlalchemy.inspection import inspect
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
from sqlalchemy.dialects.postgresql import insert
|
from sqlalchemy.dialects.postgresql import insert
|
||||||
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
|
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func, text
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from sqlalchemy.exc import ProgrammingError
|
from sqlalchemy.exc import ProgrammingError
|
||||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
||||||
|
|
@ -50,11 +50,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
connection_string: str,
|
connection_string: str,
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
embedding_engine: EmbeddingEngine,
|
embedding_engine: EmbeddingEngine,
|
||||||
|
schema_name: str = "",
|
||||||
):
|
):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
self.db_uri: str = connection_string
|
self.db_uri: str = connection_string
|
||||||
self.VECTOR_DB_LOCK = asyncio.Lock()
|
self.VECTOR_DB_LOCK = asyncio.Lock()
|
||||||
|
# Schema for project isolation; defaults to "public" if not specified
|
||||||
|
self.schema_name = schema_name if schema_name else "public"
|
||||||
|
self._schema_created = False # Track if we've already created the schema
|
||||||
|
|
||||||
relational_db = get_relational_engine()
|
relational_db = get_relational_engine()
|
||||||
|
|
||||||
|
|
@ -73,6 +77,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
self.Vector = Vector
|
self.Vector = Vector
|
||||||
|
|
||||||
|
async def _ensure_schema_exists(self):
    """
    Create the adapter's PostgreSQL schema if it does not exist yet.

    Nothing is executed when this instance has already created the schema
    (``self._schema_created`` is set) or when the schema is ``"public"``.
    """
    if self._schema_created or self.schema_name == "public":
        return

    # The schema name is embedded as a quoted identifier. Doubling any embedded
    # double quote keeps a name containing '"' from closing the identifier early
    # and changing the statement that is executed.
    quoted_schema = '"' + self.schema_name.replace('"', '""') + '"'

    async with self.engine.begin() as connection:
        await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {quoted_schema}"))

    # Only remember success once the transaction block has exited cleanly.
    self._schema_created = True
    logger.info(f"Ensured schema '{self.schema_name}' exists for project isolation")
|
||||||
|
|
||||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||||
"""
|
"""
|
||||||
Embed a list of texts into vectors using the specified embedding engine.
|
Embed a list of texts into vectors using the specified embedding engine.
|
||||||
|
|
@ -91,7 +108,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
async def has_collection(self, collection_name: str) -> bool:
|
async def has_collection(self, collection_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a specified collection exists in the database.
|
Check if a specified collection exists in the database schema.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
-----------
|
-----------
|
||||||
|
|
@ -101,18 +118,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
Returns:
|
Returns:
|
||||||
--------
|
--------
|
||||||
|
|
||||||
- bool: Returns True if the collection exists, False otherwise.
|
- bool: Returns True if the collection exists in the schema, False otherwise.
|
||||||
"""
|
"""
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
# Create a MetaData instance to load table information
|
# Create a MetaData instance to load table information
|
||||||
metadata = MetaData()
|
metadata = MetaData()
|
||||||
# Load table information from schema into MetaData
|
# Load table information from our specific schema into MetaData
|
||||||
await connection.run_sync(metadata.reflect)
|
await connection.run_sync(
|
||||||
|
lambda sync_conn: metadata.reflect(sync_conn, schema=self.schema_name)
|
||||||
|
)
|
||||||
|
|
||||||
if collection_name in metadata.tables:
|
# Tables are keyed as "schema.table_name" when schema is specified
|
||||||
return True
|
full_table_name = f"{self.schema_name}.{collection_name}"
|
||||||
else:
|
return full_table_name in metadata.tables
|
||||||
return False
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=retry_if_exception_type(
|
retry=retry_if_exception_type(
|
||||||
|
|
@ -124,6 +142,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
async def create_collection(self, collection_name: str, payload_schema=None):
|
async def create_collection(self, collection_name: str, payload_schema=None):
|
||||||
data_point_types = get_type_hints(DataPoint)
|
data_point_types = get_type_hints(DataPoint)
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
schema_name = self.schema_name # Capture for use in class definition
|
||||||
|
|
||||||
|
# Ensure the schema exists before creating tables
|
||||||
|
await self._ensure_schema_exists()
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
async with self.VECTOR_DB_LOCK:
|
async with self.VECTOR_DB_LOCK:
|
||||||
|
|
@ -145,7 +167,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = collection_name
|
__tablename__ = collection_name
|
||||||
__table_args__ = {"extend_existing": True}
|
__table_args__ = {"extend_existing": True, "schema": schema_name}
|
||||||
# PGVector requires one column to be the primary key
|
# PGVector requires one column to be the primary key
|
||||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
||||||
payload = Column(JSON)
|
payload = Column(JSON)
|
||||||
|
|
@ -161,6 +183,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
await connection.run_sync(
|
await connection.run_sync(
|
||||||
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
||||||
)
|
)
|
||||||
|
logger.debug(f"Created collection '{collection_name}' in schema '{schema_name}'")
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=retry_if_exception_type(DeadlockDetectedError),
|
retry=retry_if_exception_type(DeadlockDetectedError),
|
||||||
|
|
@ -181,6 +204,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
schema_name = self.schema_name # Capture for use in class definition
|
||||||
|
|
||||||
class PGVectorDataPoint(Base):
|
class PGVectorDataPoint(Base):
|
||||||
"""
|
"""
|
||||||
|
|
@ -194,7 +218,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = collection_name
|
__tablename__ = collection_name
|
||||||
__table_args__ = {"extend_existing": True}
|
__table_args__ = {"extend_existing": True, "schema": schema_name}
|
||||||
# PGVector requires one column to be the primary key
|
# PGVector requires one column to be the primary key
|
||||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
||||||
payload = Column(JSON)
|
payload = Column(JSON)
|
||||||
|
|
@ -265,18 +289,23 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
async def get_table(self, collection_name: str) -> Table:
    """
    Load the table backing ``collection_name`` from the configured schema.

    The schema named by ``self.schema_name`` is reflected through the async
    engine and the matching table is looked up in the reflected metadata.

    Returns:
    --------

        - Table: The reflected table for the collection.

    Raises:
    -------

        - CollectionNotFoundError: If the schema contains no table with that name.
    """

    def _reflect_schema(sync_conn):
        # Restrict reflection to this adapter's schema.
        return reflected.reflect(sync_conn, schema=self.schema_name)

    async with self.engine.begin() as connection:
        reflected = MetaData()
        await connection.run_sync(_reflect_schema)

        # Reflected tables are keyed as "schema.table_name" when a schema is given.
        table_key = f"{self.schema_name}.{collection_name}"
        if table_key not in reflected.tables:
            raise CollectionNotFoundError(
                f"Collection '{collection_name}' not found in schema '{self.schema_name}'!",
            )
        return reflected.tables[table_key]
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
||||||
|
|
@ -394,6 +423,28 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
async def drop_schema(self):
    """
    Drop the adapter's PostgreSQL schema together with everything inside it.

    Intended for cleanup/reset operations. The "public" schema is never
    dropped: a warning is logged and the call returns without executing
    any statement.
    """
    if self.schema_name == "public":
        logger.warning("Refusing to drop public schema - use delete_database() instead")
        return

    # The schema name is embedded as a quoted identifier. Doubling any embedded
    # double quote keeps a name containing '"' from closing the identifier early,
    # which matters most here because the statement is a cascading drop.
    quoted_schema = '"' + self.schema_name.replace('"', '""') + '"'

    async with self.engine.begin() as connection:
        await connection.execute(text(f"DROP SCHEMA IF EXISTS {quoted_schema} CASCADE"))

    # Let a later _ensure_schema_exists() call recreate the schema.
    self._schema_created = False
    logger.info(f"Dropped schema '{self.schema_name}' and all its contents")
|
||||||
|
|
||||||
async def prune(self):
    """Clean up the database/schema used by this adapter."""
    if self.schema_name == "public":
        # Shared public schema: keep the existing behavior and delete the database.
        await self.delete_database()
        return

    # Project-specific schema: drop the whole schema instead.
    await self.drop_schema()
|
||||||
|
|
|
||||||
|
|
@ -10,3 +10,10 @@ async def create_db_and_tables():
|
||||||
if vector_config["vector_db_provider"] == "pgvector":
|
if vector_config["vector_db_provider"] == "pgvector":
|
||||||
async with vector_engine.engine.begin() as connection:
|
async with vector_engine.engine.begin() as connection:
|
||||||
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
|
||||||
|
|
||||||
|
# NEW: Create schema if specified in config
|
||||||
|
schema_name = vector_config.get("vector_db_schema", "")
|
||||||
|
if schema_name and schema_name != "public":
|
||||||
|
await connection.execute(
|
||||||
|
text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue