# cognee/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
import asyncio
from typing import List, Optional, get_type_hints
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func, text
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine
from distributed.utils import override_distributed
from distributed.tasks.queued_add_data_points import queued_add_data_points
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from ...relational.ModelBase import Base
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ..utils import normalize_distances
from ..models.ScoredResult import ScoredResult
from ..exceptions import CollectionNotFoundError
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from .serialize_data import serialize_data
logger = get_logger("PGVectorAdapter")
class IndexSchema(DataPoint):
"""
Define a schema for indexing data points with a text field.
This class inherits from the DataPoint class and specifies the structure of a single
data point that includes a text attribute. It also includes a metadata field that
indicates which fields should be indexed.
"""
text: str
metadata: dict = {"index_fields": ["text"]}
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
def __init__(
self,
connection_string: str,
api_key: Optional[str],
embedding_engine: EmbeddingEngine,
schema_name: str = "",
):
self.api_key = api_key
self.embedding_engine = embedding_engine
self.db_uri: str = connection_string
self.VECTOR_DB_LOCK = asyncio.Lock()
# Schema for project isolation; defaults to "public" if not specified
self.schema_name = schema_name if schema_name else "public"
self._schema_created = False # Track if we've already created the schema
relational_db = get_relational_engine()
# If PostgreSQL is used, we must reuse the relational engine and sessionmaker
if relational_db.engine.dialect.name == "postgresql":
self.engine = relational_db.engine
self.sessionmaker = relational_db.sessionmaker
else:
# Otherwise, create new engine and sessionmaker instances
self.engine = create_async_engine(self.db_uri)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
# Imported here so it is available on the instance:
# code that reflects tables from the database needs to know what a Vector column type is
from pgvector.sqlalchemy import Vector
self.Vector = Vector
async def _ensure_schema_exists(self):
"""Create the PostgreSQL schema if it doesn't exist (for project isolation)."""
if self._schema_created or self.schema_name == "public":
return
async with self.engine.begin() as connection:
# Use quoted identifier to handle any special characters
await connection.execute(
text(f'CREATE SCHEMA IF NOT EXISTS "{self.schema_name}"')
)
self._schema_created = True
logger.info(f"Ensured schema '{self.schema_name}' exists for project isolation")
async def embed_data(self, data: list[str]) -> list[list[float]]:
"""
Embed a list of texts into vectors using the specified embedding engine.
Parameters:
-----------
- data (list[str]): A list of strings to be embedded into vectors.
Returns:
--------
- list[list[float]]: A list of lists of floats representing embedded vectors.
"""
return await self.embedding_engine.embed_text(data)
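# Usage sketch (texts are illustrative; requires a configured embedding engine):
#
#   vectors = await adapter.embed_data(["first chunk", "second chunk"])
#   assert len(vectors) == 2
#   assert len(vectors[0]) == adapter.embedding_engine.get_vector_size()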
async def has_collection(self, collection_name: str) -> bool:
"""
Check if a specified collection exists in the database schema.
Parameters:
-----------
- collection_name (str): The name of the collection to check for existence.
Returns:
--------
- bool: Returns True if the collection exists in the schema, False otherwise.
"""
async with self.engine.begin() as connection:
# Create a MetaData instance to load table information
metadata = MetaData()
# Load table information from our specific schema into MetaData
await connection.run_sync(
lambda sync_conn: metadata.reflect(sync_conn, schema=self.schema_name)
)
# Tables are keyed as "schema.table_name" when schema is specified
full_table_name = f"{self.schema_name}.{collection_name}"
return full_table_name in metadata.tables
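# Reflection keys tables as "schema.table", so a collection "documents_text"
# in schema "project_a" is matched against "project_a.documents_text", e.g.:
#
#   exists = await adapter.has_collection("documents_text")  # name illustrative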
@retry(
retry=retry_if_exception_type(
(DuplicateTableError, UniqueViolationError, ProgrammingError)
),
stop=stop_after_attempt(5),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
async def create_collection(self, collection_name: str, payload_schema=None):
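# NOTE: payload_schema is currently unused; the table layout is always
# fixed as (id, payload, vector).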
data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size()
schema_name = self.schema_name # Capture for use in class definition
# Ensure the schema exists before creating tables
await self._ensure_schema_exists()
if not await self.has_collection(collection_name):
async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base):
"""
Represent a point in a vector data space with associated data and vector representation.
This class inherits from Base and is associated with a database table defined by
__tablename__. It maintains the following public methods and instance variables:
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
Instance variables:
- id: Identifier for the data point, defined by data_point_types.
- payload: JSON data associated with the data point.
- vector: Vector representation of the data point, with size defined by vector_size.
"""
__tablename__ = collection_name
__table_args__ = {"extend_existing": True, "schema": schema_name}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
async with self.engine.begin() as connection:
if Base.metadata.tables:
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
logger.debug(f"Created collection '{collection_name}' in schema '{schema_name}'")
@retry(
retry=retry_if_exception_type(DeadlockDetectedError),
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=2, min=1, max=6),
)
@override_distributed(queued_add_data_points)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
data_point_types = get_type_hints(DataPoint)
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name=collection_name,
payload_schema=type(data_points[0]),
)
data_vectors = await self.embed_data(
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
)
vector_size = self.embedding_engine.get_vector_size()
schema_name = self.schema_name # Capture for use in class definition
class PGVectorDataPoint(Base):
"""
Represents a data point in a PGVector database. This class maps to a table defined by
the SQLAlchemy ORM.
It contains the following public instance variables:
- id: An identifier for the data point.
- payload: A JSON object containing additional data related to the data point.
- vector: A vector representation of the data point, configured to the specified size.
"""
__tablename__ = collection_name
__table_args__ = {"extend_existing": True, "schema": schema_name}
# PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
async with self.get_async_session() as session:
pgvector_data_points = []
for data_index, data_point in enumerate(data_points):
# Earlier revisions looked up each row and updated it in place; the bulk
# insert below with on_conflict_do_nothing supersedes that per-row upsert.
pgvector_data_points.append(
PGVectorDataPoint(
id=data_point.id,
vector=data_vectors[data_index],
payload=serialize_data(data_point.model_dump()),
)
)
def to_dict(obj):
return {
column.key: getattr(obj, column.key)
for column in inspect(obj).mapper.column_attrs
}
# Insert all rows in a single statement; conflicting ids are skipped below.
insert_statement = insert(PGVectorDataPoint).values(
[to_dict(data_point) for data_point in pgvector_data_points]
)
insert_statement = insert_statement.on_conflict_do_nothing(index_elements=["id"])
await session.execute(insert_statement)
await session.commit()
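# Insertion sketch, assuming a DataPoint subclass with an embeddable text field
# (IndexSchema above fits; ids and names are illustrative):
#
#   from uuid import uuid4
#   points = [IndexSchema(id=uuid4(), text="hello"), IndexSchema(id=uuid4(), text="world")]
#   await adapter.create_data_points("documents_text", points)
#
# Re-running with the same ids is a no-op thanks to on_conflict_do_nothing.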
async def create_vector_index(self, index_name: str, index_property_name: str):
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
await self.create_data_points(
f"{index_name}_{index_property_name}",
[
IndexSchema(
id=data_point.id,
text=DataPoint.get_embeddable_data(data_point),
)
for data_point in data_points
],
)
async def get_table(self, collection_name: str) -> Table:
"""
Dynamically loads a table using the given collection name
from the configured schema with an async engine.
"""
async with self.engine.begin() as connection:
# Create a MetaData instance to load table information
metadata = MetaData()
# Load table information from our specific schema into MetaData
await connection.run_sync(
lambda sync_conn: metadata.reflect(sync_conn, schema=self.schema_name)
)
# Tables are keyed as "schema.table_name" when schema is specified
full_table_name = f"{self.schema_name}.{collection_name}"
if full_table_name in metadata.tables:
return metadata.tables[full_table_name]
else:
raise CollectionNotFoundError(
f"Collection '{collection_name}' not found in schema '{self.schema_name}'!",
)
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
async with self.get_async_session() as session:
results = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
results = results.all()
return [
ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
for result in results
]
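# Retrieval sketch (ids illustrative); score is fixed at 0 because no
# similarity ranking is involved:
#
#   results = await adapter.retrieve("documents_text", [str(some_id)])
#   payload = results[0].payload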
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: Optional[int] = 15,
with_vector: bool = False,
) -> List[ScoredResult]:
if query_text is None and query_vector is None:
raise MissingQueryParameterError()
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
if limit is None:
async with self.get_async_session() as session:
query = select(func.count()).select_from(PGVectorDataPoint)
result = await session.execute(query)
limit = result.scalar_one()
# An empty collection (or non-positive limit) yields no results; skip the query
if limit <= 0:
return []
# Initialize to an empty result in case the query below returns no rows
closest_items = []
# Use async session to connect to the database
async with self.get_async_session() as session:
query = select(
PGVectorDataPoint,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
).order_by("similarity")
if limit > 0:
query = query.limit(limit)
# Find closest vectors to query_vector
closest_items = await session.execute(query)
vector_list = []
# Collect ids, payloads, and raw cosine distances for normalization
for vector in closest_items.all():
vector_list.append(
{
"id": parse_id(str(vector.id)),
"payload": vector.payload,
"_distance": vector.similarity,
}
)
if len(vector_list) == 0:
return []
# Normalize vector distance and add this as score information to vector_list
normalized_values = normalize_distances(vector_list)
for i, score in enumerate(normalized_values):
vector_list[i]["score"] = score
# Create and return ScoredResult objects
return [
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
for row in vector_list
]
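# Search sketch (names illustrative). Either query_text or query_vector must be
# given; scores are produced by normalize_distances from the raw cosine distances:
#
#   hits = await adapter.search("documents_text", query_text="greeting", limit=5)
#   for hit in hits:
#       print(hit.id, hit.score, hit.payload)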
async def batch_search(
self,
collection_name: str,
query_texts: List[str],
limit: Optional[int] = None,
with_vectors: bool = False,
):
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
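# Batch sketch: all queries are embedded in one call, then searched concurrently:
#
#   per_query = await adapter.batch_search("documents_text", ["hello", "bye"], limit=3)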
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
results = await session.execute(
delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
await session.commit()
return results
async def drop_schema(self):
"""
Drop the entire PostgreSQL schema and all its tables.
This is useful for cleanup/reset operations.
Only drops non-public schemas to prevent accidental data loss.
"""
if self.schema_name == "public":
logger.warning("Refusing to drop public schema - use delete_database() instead")
return
async with self.engine.begin() as connection:
await connection.execute(
text(f'DROP SCHEMA IF EXISTS "{self.schema_name}" CASCADE')
)
self._schema_created = False
logger.info(f"Dropped schema '{self.schema_name}' and all its contents")
async def prune(self):
"""Clean up the database/schema."""
if self.schema_name != "public":
# For project-specific schemas, drop the entire schema
await self.drop_schema()
else:
# For public schema, just delete the database (existing behavior)
await self.delete_database()
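# End-to-end usage sketch (assumes a reachable PostgreSQL instance with the
# pgvector extension and a configured EmbeddingEngine; the connection string,
# engine, and names below are illustrative):
#
#   import asyncio
#   from uuid import uuid4
#
#   async def demo():
#       adapter = PGVectorAdapter(
#           "postgresql+asyncpg://user:pass@localhost:5432/cognee",
#           api_key=None,
#           embedding_engine=engine,  # hypothetical EmbeddingEngine instance
#       )
#       await adapter.create_collection("demo_text")
#       points = [IndexSchema(id=uuid4(), text="example")]
#       await adapter.create_data_points("demo_text", points)
#       print(await adapter.search("demo_text", query_text="example", limit=3))
#
#   asyncio.run(demo())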