diff --git a/cognee/context_global_variables.py b/cognee/context_global_variables.py
index d52de4b4e..1949e7345 100644
--- a/cognee/context_global_variables.py
+++ b/cognee/context_global_variables.py
@@ -61,7 +61,8 @@ async def set_database_global_context_variables(dataset: Union[str, UUID], user_
             databases_directory_path, dataset_database.vector_database_name
         ),
         "vector_db_key": "",
-        "vector_db_provider": "lancedb",
+        "vector_db_provider": "pgvector",
+        "vector_db_schema": dataset_database.vector_database_name,  # PostgreSQL schema for pgvector isolation
     }
 
     graph_config = {
diff --git a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py
index 29156025d..99e58cc1e 100644
--- a/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py
+++ b/cognee/infrastructure/databases/utils/get_or_create_dataset_database.py
@@ -31,9 +31,12 @@ async def get_or_create_dataset_database(
     db_engine = get_relational_engine()
 
     dataset_id = await get_unique_dataset_id(dataset, user)
+    # The vector database name doubles as a PostgreSQL schema name (not a file name).
+    # Postgres identifiers: lowercase, no dashes, max 63 chars
+    dataset_id_hex = str(dataset_id).replace("-", "")
+    vector_db_name = f"vec_{dataset_id_hex}"  # e.g., "vec_9f305871db4f4dc6..."
 
-    vector_db_name = f"{dataset_id}.lance.db"
-    graph_db_name = f"{dataset_id}.pkl"
+    graph_db_name = f"graph_{dataset_id_hex}.pkl"
 
     async with db_engine.get_async_session() as session:
         # Create dataset if it doesn't exist
diff --git a/cognee/infrastructure/databases/vector/config.py b/cognee/infrastructure/databases/vector/config.py
index b6d3ae644..df59948cf 100644
--- a/cognee/infrastructure/databases/vector/config.py
+++ b/cognee/infrastructure/databases/vector/config.py
@@ -26,6 +26,7 @@ class VectorConfig(BaseSettings):
     vector_db_port: int = 1234
     vector_db_key: str = ""
     vector_db_provider: str = "lancedb"
+    vector_db_schema: str = ""  # PostgreSQL schema for pgvector isolation
 
     model_config = SettingsConfigDict(env_file=".env", extra="allow")
 
@@ -60,6 +61,7 @@ class VectorConfig(BaseSettings):
             "vector_db_port": self.vector_db_port,
             "vector_db_key": self.vector_db_key,
             "vector_db_provider": self.vector_db_provider,
+            "vector_db_schema": self.vector_db_schema,  # PostgreSQL schema for pgvector isolation
         }
 
 
diff --git a/cognee/infrastructure/databases/vector/create_vector_engine.py b/cognee/infrastructure/databases/vector/create_vector_engine.py
index d1cf855d7..4672fc2f8 100644
--- a/cognee/infrastructure/databases/vector/create_vector_engine.py
+++ b/cognee/infrastructure/databases/vector/create_vector_engine.py
@@ -10,6 +10,7 @@ def create_vector_engine(
     vector_db_url: str,
     vector_db_port: str = "",
     vector_db_key: str = "",
+    vector_db_schema: str = "",
 ):
     """
     Create a vector database engine based on the specified provider.
@@ -30,7 +31,9 @@ def create_vector_engine(
         - vector_db_key (str): The API key or access token for the vector database instance.
         - vector_db_provider (str): The name of the vector database provider to use (e.g.,
           'pgvector').
-
+        - vector_db_schema (str): The PostgreSQL schema used by the pgvector adapter.
+          Optional; when empty the adapter falls back to the "public" schema.
+
     Returns:
     --------
 
@@ -76,6 +79,7 @@ def create_vector_engine(
             connection_string,
             vector_db_key,
             embedding_engine,
+            schema_name=vector_db_schema,
         )
 
     elif vector_db_provider.lower() == "chromadb":
diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
index 1986fae48..dfb3a885f 100644
--- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -3,7 +3,7 @@ from typing import List, Optional, get_type_hints
 from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import Mapped, mapped_column
 from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
+from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func, text
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
 from sqlalchemy.exc import ProgrammingError
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@@ -50,11 +50,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         connection_string: str,
         api_key: Optional[str],
         embedding_engine: EmbeddingEngine,
+        schema_name: str = "",
     ):
         self.api_key = api_key
         self.embedding_engine = embedding_engine
         self.db_uri: str = connection_string
         self.VECTOR_DB_LOCK = asyncio.Lock()
+        # Schema for project isolation; defaults to "public" if not specified
+        self.schema_name = schema_name or "public"
+        self._schema_created = False  # Track if we've already created the schema
 
         relational_db = get_relational_engine()
 
@@ -73,6 +77,24 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
         self.Vector = Vector
 
+    def _quoted_schema_name(self) -> str:
+        """Return the schema name as a double-quoted, escaped PostgreSQL identifier."""
+        # Doubling embedded double quotes keeps the name a single identifier in raw DDL.
+        return '"' + self.schema_name.replace('"', '""') + '"'
+
+    async def _ensure_schema_exists(self):
+        """Create the PostgreSQL schema if it doesn't exist (for project isolation)."""
+        if self._schema_created or self.schema_name == "public":
+            return
+
+        async with self.engine.begin() as connection:
+            await connection.execute(
+                text(f"CREATE SCHEMA IF NOT EXISTS {self._quoted_schema_name()}")
+            )
+        # Only remember success once the transaction block above has exited cleanly.
+        self._schema_created = True
+        logger.info(f"Ensured schema '{self.schema_name}' exists for project isolation")
+
     async def embed_data(self, data: list[str]) -> list[list[float]]:
         """
         Embed a list of texts into vectors using the specified embedding engine.
@@ -91,7 +113,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
     async def has_collection(self, collection_name: str) -> bool:
         """
-        Check if a specified collection exists in the database.
+        Check if a specified collection exists in the database schema.
 
         Parameters:
         -----------
@@ -101,18 +123,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         Returns:
         --------
 
-            - bool: Returns True if the collection exists, False otherwise.
+            - bool: Returns True if the collection exists in the schema, False otherwise.
         """
         async with self.engine.begin() as connection:
             # Create a MetaData instance to load table information
             metadata = MetaData()
-            # Load table information from schema into MetaData
-            await connection.run_sync(metadata.reflect)
+            # Load table information from our specific schema into MetaData
+            await connection.run_sync(
+                lambda sync_conn: metadata.reflect(sync_conn, schema=self.schema_name)
+            )
 
-            if collection_name in metadata.tables:
-                return True
-            else:
-                return False
+            # Tables are keyed as "schema.table_name" when schema is specified
+            full_table_name = f"{self.schema_name}.{collection_name}"
+            return full_table_name in metadata.tables
 
     @retry(
         retry=retry_if_exception_type(
@@ -124,6 +147,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
     async def create_collection(self, collection_name: str, payload_schema=None):
         data_point_types = get_type_hints(DataPoint)
         vector_size = self.embedding_engine.get_vector_size()
+        schema_name = self.schema_name  # Capture for use in class definition
+
+        # Ensure the schema exists before creating tables
+        await self._ensure_schema_exists()
 
         if not await self.has_collection(collection_name):
             async with self.VECTOR_DB_LOCK:
@@ -145,7 +172,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                         """
 
                         __tablename__ = collection_name
-                        __table_args__ = {"extend_existing": True}
+                        __table_args__ = {"extend_existing": True, "schema": schema_name}
                         # PGVector requires one column to be the primary key
                         id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
                         payload = Column(JSON)
@@ -161,6 +188,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                             await connection.run_sync(
                                 Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
                             )
+                    logger.debug(f"Created collection '{collection_name}' in schema '{schema_name}'")
 
     @retry(
         retry=retry_if_exception_type(DeadlockDetectedError),
@@ -181,6 +209,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         )
 
         vector_size = self.embedding_engine.get_vector_size()
+        schema_name = self.schema_name  # Capture for use in class definition
 
         class PGVectorDataPoint(Base):
             """
@@ -194,7 +223,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             """
 
             __tablename__ = collection_name
-            __table_args__ = {"extend_existing": True}
+            __table_args__ = {"extend_existing": True, "schema": schema_name}
             # PGVector requires one column to be the primary key
             id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
             payload = Column(JSON)
@@ -265,18 +294,23 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
     async def get_table(self, collection_name: str) -> Table:
         """
         Dynamically loads a table using the given collection name
-        with an async engine.
+        from the configured schema with an async engine.
         """
         async with self.engine.begin() as connection:
             # Create a MetaData instance to load table information
             metadata = MetaData()
-            # Load table information from schema into MetaData
-            await connection.run_sync(metadata.reflect)
-            if collection_name in metadata.tables:
-                return metadata.tables[collection_name]
+            # Load table information from our specific schema into MetaData
+            await connection.run_sync(
+                lambda sync_conn: metadata.reflect(sync_conn, schema=self.schema_name)
+            )
+
+            # Tables are keyed as "schema.table_name" when schema is specified
+            full_table_name = f"{self.schema_name}.{collection_name}"
+            if full_table_name in metadata.tables:
+                return metadata.tables[full_table_name]
             else:
                 raise CollectionNotFoundError(
-                    f"Collection '{collection_name}' not found!",
+                    f"Collection '{collection_name}' not found in schema '{self.schema_name}'!",
                 )
 
     async def retrieve(self, collection_name: str, data_point_ids: List[str]):
@@ -394,6 +428,28 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
             await session.commit()
             return results
 
+    async def drop_schema(self):
+        """
+        Drop the entire PostgreSQL schema and all its tables.
+        This is useful for cleanup/reset operations.
+        Only drops non-public schemas to prevent accidental data loss.
+        """
+        if self.schema_name == "public":
+            logger.warning("Refusing to drop public schema - use delete_database() instead")
+            return
+
+        async with self.engine.begin() as connection:
+            await connection.execute(
+                text(f"DROP SCHEMA IF EXISTS {self._quoted_schema_name()} CASCADE")
+            )
+        self._schema_created = False
+        logger.info(f"Dropped schema '{self.schema_name}' and all its contents")
+
     async def prune(self):
-        # Clean up the database if it was set up as temporary
-        await self.delete_database()
+        """Clean up the database/schema."""
+        if self.schema_name != "public":
+            # For project-specific schemas, drop the entire schema
+            await self.drop_schema()
+        else:
+            # For public schema, just delete the database (existing behavior)
+            await self.delete_database()
diff --git a/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py b/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
index c2477086d..32bf45a9e 100644
--- a/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
+++ b/cognee/infrastructure/databases/vector/pgvector/create_db_and_tables.py
@@ -10,3 +10,12 @@ async def create_db_and_tables():
     if vector_config["vector_db_provider"] == "pgvector":
         async with vector_engine.engine.begin() as connection:
             await connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+
+            # Create the configured schema up front (skipped for the default "public")
+            schema_name = vector_config.get("vector_db_schema", "")
+            if schema_name and schema_name != "public":
+                # Double embedded quotes so the name stays a single quoted identifier
+                quoted_schema = '"' + schema_name.replace('"', '""') + '"'
+                await connection.execute(
+                    text(f"CREATE SCHEMA IF NOT EXISTS {quoted_schema}")
+                )