cognee/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py

412 lines
16 KiB
Python

import asyncio
from typing import List, Optional, get_type_hints
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import JSON, Column, Table, select, delete, MetaData, func
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.exc import ProgrammingError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from asyncpg import DeadlockDetectedError, DuplicateTableError, UniqueViolationError
from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine
from distributed.utils import override_distributed
from distributed.tasks.queued_add_data_points import queued_add_data_points
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
from ...relational.ModelBase import Base
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ..utils import normalize_distances
from ..models.ScoredResult import ScoredResult
from ..exceptions import CollectionNotFoundError
from ..vector_db_interface import VectorDBInterface
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from .serialize_data import serialize_data
# Module-level logger for this adapter, named "PGVectorAdapter".
logger = get_logger("PGVectorAdapter")
class IndexSchema(DataPoint):
    """
    Lightweight data point stored in vector index collections.

    It carries a single ``text`` attribute. Its ``metadata["index_fields"]`` entry
    names ``text`` as the field that should be indexed.
    """

    # Text content of the indexed data point.
    text: str

    # Marks ``text`` as the indexed field.
    metadata: dict = dict(index_fields=["text"])
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
    """
    Vector database adapter backed by PostgreSQL with the pgvector extension.

    Each collection is a table with three columns:
    - an ``id`` primary key (typed like ``DataPoint.id``),
    - a JSON ``payload`` column,
    - a pgvector ``vector`` column sized by the embedding engine.

    Session handling (``get_async_session``) and ``delete_database`` are not
    defined here. Presumably they come from ``SQLAlchemyAdapter``.
    """

    def __init__(
        self,
        connection_string: str,
        api_key: Optional[str],
        embedding_engine: EmbeddingEngine,
    ):
        """
        Configure the engine, sessionmaker and embedding engine.

        Parameters:
        -----------
        - connection_string (str): Database URI. It is used to create a dedicated
          engine only when the relational engine is not PostgreSQL.
        - api_key (Optional[str]): Stored on the instance; not used by this class.
        - embedding_engine (EmbeddingEngine): Used to embed text and to report the
          vector dimension.
        """
        self.api_key = api_key
        self.embedding_engine = embedding_engine
        self.db_uri: str = connection_string
        # Guards table creation within this adapter instance (see create_collection).
        self.VECTOR_DB_LOCK = asyncio.Lock()

        relational_db = get_relational_engine()

        if relational_db.engine.dialect.name == "postgresql":
            # PostgreSQL is already the relational store: reuse the same engine
            # and sessionmaker instead of creating new ones.
            self.engine = relational_db.engine
            self.sessionmaker = relational_db.sessionmaker
        else:
            # Otherwise create a dedicated engine and sessionmaker for pgvector.
            self.engine = create_async_engine(self.db_uri)
            self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)

        # Imported here, during initialisation, because functions that reflect
        # tables from the database need to know what the Vector column type is.
        # The type is also kept on the instance for declaring collection models.
        from pgvector.sqlalchemy import Vector

        self.Vector = Vector

    def _create_data_point_model(self, collection_name: str, vector_size: int):
        """
        Declare and return the ORM model for the table ``collection_name``.

        The model is declared on the shared ``Base``. It has an ``id`` primary key,
        a JSON ``payload`` column and a ``vector`` column with ``vector_size``
        dimensions. ``extend_existing`` allows the same table to be declared again
        on later calls.
        """
        data_point_types = get_type_hints(DataPoint)
        vector_column_type = self.Vector(vector_size)

        class PGVectorDataPoint(Base):
            """Row of a collection table: id, JSON payload and embedding vector."""

            __tablename__ = collection_name
            __table_args__ = {"extend_existing": True}

            # PGVector requires one column to be the primary key
            id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
            payload = Column(JSON)
            vector = Column(vector_column_type)

        return PGVectorDataPoint

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        """
        Embed a list of texts with the configured embedding engine.

        Parameters:
        -----------
        - data (list[str]): Texts to embed.

        Returns:
        --------
        - list[list[float]]: One vector per input text.
        """
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        """
        Check whether a table named ``collection_name`` exists.

        The full database schema is reflected to answer this.

        Parameters:
        -----------
        - collection_name (str): Table name to look for.

        Returns:
        --------
        - bool: True if the table exists, False otherwise.
        """
        async with self.engine.begin() as connection:
            metadata = MetaData()
            # Load table information from the database schema into the metadata.
            await connection.run_sync(metadata.reflect)
            return collection_name in metadata.tables

    @retry(
        retry=retry_if_exception_type(
            (DuplicateTableError, UniqueViolationError, ProgrammingError)
        ),
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=2, min=1, max=6),
    )
    async def create_collection(self, collection_name: str, payload_schema=None):
        """
        Create the table for ``collection_name`` if it does not exist yet.

        Existence is checked before and again after acquiring ``VECTOR_DB_LOCK``.
        Tasks sharing this adapter therefore do not both issue the CREATE. Calls
        failing with DuplicateTableError, UniqueViolationError or ProgrammingError
        are retried (up to 5 attempts, exponential backoff).

        Parameters:
        -----------
        - collection_name (str): Name of the table to create.
        - payload_schema: Accepted for interface compatibility; not used here.
        """
        vector_size = self.embedding_engine.get_vector_size()

        if await self.has_collection(collection_name):
            return

        async with self.VECTOR_DB_LOCK:
            # Re-check under the lock: another task may have created it meanwhile.
            if await self.has_collection(collection_name):
                return

            PGVectorDataPoint = self._create_data_point_model(collection_name, vector_size)

            async with self.engine.begin() as connection:
                if Base.metadata.tables:
                    await connection.run_sync(
                        Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
                    )

    @retry(
        retry=retry_if_exception_type(DeadlockDetectedError),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=2, min=1, max=6),
    )
    @override_distributed(queued_add_data_points)
    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        """
        Embed data points and insert them into a collection.

        The collection is created first if missing. Rows whose ``id`` already
        exists are left unchanged (``ON CONFLICT (id) DO NOTHING``), so existing
        points are not updated. An empty ``data_points`` list is a no-op.
        Deadlocks are retried up to 3 attempts.

        Parameters:
        -----------
        - collection_name (str): Target collection (table) name.
        - data_points (List[DataPoint]): Points to embed and store. The payload is
          ``serialize_data(data_point.model_dump())``.
        """
        if not data_points:
            # Nothing to store. This also avoids indexing data_points[0] below and
            # issuing an INSERT without rows (DEFAULT VALUES), which fails on the
            # NOT NULL primary key.
            return

        if not await self.has_collection(collection_name):
            await self.create_collection(
                collection_name=collection_name,
                payload_schema=type(data_points[0]),
            )

        data_vectors = await self.embed_data(
            [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
        )

        vector_size = self.embedding_engine.get_vector_size()
        PGVectorDataPoint = self._create_data_point_model(collection_name, vector_size)

        rows = [
            {
                "id": data_point.id,
                "payload": serialize_data(data_point.model_dump()),
                "vector": data_vectors[index],
            }
            for index, data_point in enumerate(data_points)
        ]

        insert_statement = (
            insert(PGVectorDataPoint)
            .values(rows)
            .on_conflict_do_nothing(index_elements=["id"])
        )

        async with self.get_async_session() as session:
            await session.execute(insert_statement)
            await session.commit()

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """Ensure the collection ``{index_name}_{index_property_name}`` exists."""
        await self.create_collection(f"{index_name}_{index_property_name}")

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
    ):
        """
        Store data points in the index collection ``{index_name}_{index_property_name}``.

        Each point is wrapped as an ``IndexSchema``. The wrapper keeps the point's
        id and uses ``DataPoint.get_embeddable_data`` as its ``text``.
        """
        await self.create_data_points(
            f"{index_name}_{index_property_name}",
            [
                IndexSchema(
                    id=data_point.id,
                    text=DataPoint.get_embeddable_data(data_point),
                )
                for data_point in data_points
            ],
        )

    async def get_table(self, collection_name: str) -> Table:
        """
        Reflect and return the table for ``collection_name``.

        Raises:
        -------
        - CollectionNotFoundError: If no such table exists.
        """
        async with self.engine.begin() as connection:
            metadata = MetaData()
            # Load table information from the database schema into the metadata.
            await connection.run_sync(metadata.reflect)

            if collection_name in metadata.tables:
                return metadata.tables[collection_name]

            raise CollectionNotFoundError(
                f"Collection '{collection_name}' not found!",
            )

    async def retrieve(self, collection_name: str, data_point_ids: List[str]):
        """
        Fetch rows by id from a collection.

        Returns:
        --------
        - List of ``ScoredResult`` with the row's id and payload and ``score=0``.
        """
        PGVectorDataPoint = await self.get_table(collection_name)

        async with self.get_async_session() as session:
            results = await session.execute(
                select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
            )
            results = results.all()

        return [
            ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
            for result in results
        ]

    async def search(
        self,
        collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        limit: Optional[int] = 15,
        with_vector: bool = False,
        include_payload: bool = False,
    ) -> List[ScoredResult]:
        """
        Return the rows closest to the query by cosine distance, nearest first.

        Parameters:
        -----------
        - collection_name (str): Collection to search.
        - query_text (Optional[str]): Embedded when ``query_vector`` is not given.
        - query_vector (Optional[List[float]]): Query embedding.
        - limit (Optional[int]): Maximum number of results. ``None`` searches
          every row in the collection. A value <= 0 returns an empty list.
        - with_vector (bool): Accepted for interface compatibility; not used.
        - include_payload (bool): Whether to fetch and return row payloads.

        Returns:
        --------
        - List[ScoredResult]: Scores come from ``normalize_distances`` applied to
          the raw cosine distances.

        Raises:
        -------
        - MissingQueryParameterError: If neither query_text nor query_vector is given.
        - CollectionNotFoundError: If the collection table does not exist.
        """
        if query_text is None and query_vector is None:
            raise MissingQueryParameterError()

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        PGVectorDataPoint = await self.get_table(collection_name)

        if limit is None:
            # No limit requested: search across every row currently stored.
            async with self.get_async_session() as session:
                count_result = await session.execute(
                    select(func.count()).select_from(PGVectorDataPoint)
                )
                limit = count_result.scalar_one()

        # Empty collection or non-positive limit: nothing to search.
        if limit <= 0:
            return []

        # Skip the payload column when it is not needed, to reduce transferred data.
        select_columns = (
            [PGVectorDataPoint]
            if include_payload
            else [PGVectorDataPoint.c.id, PGVectorDataPoint.c.vector]
        )

        query = (
            select(
                *select_columns,
                PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
            )
            .order_by("similarity")
            .limit(limit)
        )

        async with self.get_async_session() as session:
            # Materialise rows while the session is still open.
            closest_rows = (await session.execute(query)).all()

        if not closest_rows:
            return []

        vector_list = [
            {
                "id": parse_id(str(row.id)),
                "payload": row.payload if include_payload else None,
                "_distance": row.similarity,
            }
            for row in closest_rows
        ]

        # Attach normalized distances as scores.
        for item, score in zip(vector_list, normalize_distances(vector_list)):
            item["score"] = score

        return [
            ScoredResult(
                id=item["id"],
                payload=item["payload"],
                score=item.get("score"),
            )
            for item in vector_list
        ]

    async def batch_search(
        self,
        collection_name: str,
        query_texts: List[str],
        limit: Optional[int] = None,
        with_vectors: bool = False,
        include_payload: bool = False,
    ):
        """
        Run ``search`` concurrently for several query texts.

        All texts are embedded in one call. Results are returned as one list of
        ``ScoredResult`` per query, in input order (``asyncio.gather`` preserves
        order).
        """
        query_vectors = await self.embedding_engine.embed_text(query_texts)

        return await asyncio.gather(
            *[
                self.search(
                    collection_name=collection_name,
                    query_vector=query_vector,
                    limit=limit,
                    with_vector=with_vectors,
                    include_payload=include_payload,
                )
                for query_vector in query_vectors
            ]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        """
        Delete rows with the given ids from a collection and commit.

        Returns:
        --------
        - The result object of the executed DELETE statement.

        Raises:
        -------
        - CollectionNotFoundError: If the collection table does not exist.
        """
        PGVectorDataPoint = await self.get_table(collection_name)

        async with self.get_async_session() as session:
            results = await session.execute(
                delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
            )
            await session.commit()
            return results

    async def prune(self):
        """Clean up the database by calling ``delete_database`` (defined outside this class)."""
        await self.delete_database()