cognee/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
Igor Ilic c4a6c94675
fix: Resolve duplicate chunk issue for PGVector [COG-895] (#705)

## Description
Resolves duplicate chunk rows for PGVector: `create_data_points` now updates an existing data point with the same id instead of inserting a duplicate row.
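
A minimal sketch of the intended behaviour, assuming an already configured `PGVectorAdapter` (the collection name and data point below are placeholders, not code from this PR): ingesting the same data point twice should update the existing row instead of adding a second one.

```python
from uuid import uuid4

from cognee.infrastructure.databases.vector.pgvector.PGVectorAdapter import (
    IndexSchema,
    PGVectorAdapter,
)


async def ingest_twice(adapter: PGVectorAdapter) -> None:
    # One chunk-like data point; IndexSchema embeds its `text` field.
    chunk = IndexSchema(id=uuid4(), text="the same chunk text")

    # Ingest the same data point twice, e.g. a pipeline re-run.
    await adapter.create_data_points("chunks_text", [chunk])
    await adapter.create_data_points("chunks_text", [chunk])

    # With this fix the existing row is updated in place, not duplicated.
    results = await adapter.retrieve("chunks_text", [str(chunk.id)])
    assert len(results) == 1
```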

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.
2025-04-07 18:03:36 +02:00


import asyncio
from typing import List, Optional, get_type_hints
from uuid import UUID, uuid4
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine
from ...relational.ModelBase import Base
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
from .serialize_data import serialize_data
from ..utils import normalize_distances
class IndexSchema(DataPoint):
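    """DataPoint schema used for vector indexes; the ``text`` field is the one that gets embedded."""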
text: str
metadata: dict = {"index_fields": ["text"]}
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
def __init__(
self,
connection_string: str,
api_key: Optional[str],
embedding_engine: EmbeddingEngine,
):
self.api_key = api_key
self.embedding_engine = embedding_engine
self.db_uri: str = connection_string
relational_db = get_relational_engine()
        # If PostgreSQL is used we must reuse the same engine and sessionmaker
if relational_db.engine.dialect.name == "postgresql":
self.engine = relational_db.engine
self.sessionmaker = relational_db.sessionmaker
else:
            # Otherwise, create new engine and sessionmaker instances
self.engine = create_async_engine(self.db_uri)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
        # Imported here and stored on the instance because functions that read
        # tables from the database need to know what a Vector column type is
from pgvector.sqlalchemy import Vector
self.Vector = Vector
async def embed_data(self, data: list[str]) -> list[list[float]]:
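        """Embed a list of strings into vectors using the configured embedding engine."""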
return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool:
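        """Return True if a table named ``collection_name`` exists in the database schema."""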
async with self.engine.begin() as connection:
# Create a MetaData instance to load table information
metadata = MetaData()
# Load table information from schema into MetaData
await connection.run_sync(metadata.reflect)
            return collection_name in metadata.tables
async def create_collection(self, collection_name: str, payload_schema=None):
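        """Create a table for ``collection_name`` (primary key, id, JSON payload, vector column) if it does not already exist."""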
data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size()
if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
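        """Insert or update data points in the collection.

        Existing rows (matched by data point id) are updated in place, so
        re-ingesting the same chunk does not create duplicate rows.
        """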
data_point_types = get_type_hints(DataPoint)
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name=collection_name,
payload_schema=type(data_points[0]),
)
data_vectors = await self.embed_data(
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
)
vector_size = self.embedding_engine.get_vector_size()
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key
primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
self.payload = payload
self.vector = vector
async with self.get_async_session() as session:
pgvector_data_points = []
for data_index, data_point in enumerate(data_points):
# Check to see if data should be updated or a new data item should be created
data_point_db = (
await session.execute(
select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id)
)
).scalar_one_or_none()
                # If the data point exists, update it; if not, create a new one
if data_point_db:
data_point_db.id = data_point.id
data_point_db.vector = data_vectors[data_index]
data_point_db.payload = serialize_data(data_point.model_dump())
pgvector_data_points.append(data_point_db)
else:
pgvector_data_points.append(
PGVectorDataPoint(
id=data_point.id,
vector=data_vectors[data_index],
payload=serialize_data(data_point.model_dump()),
)
)
session.add_all(pgvector_data_points)
await session.commit()
async def create_vector_index(self, index_name: str, index_property_name: str):
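        """Create the collection backing the ``{index_name}_{index_property_name}`` vector index."""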
await self.create_collection(f"{index_name}_{index_property_name}")
async def index_data_points(
self, index_name: str, index_property_name: str, data_points: list[DataPoint]
):
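        """Wrap data points in IndexSchema and store them in the ``{index_name}_{index_property_name}`` collection."""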
await self.create_data_points(
f"{index_name}_{index_property_name}",
[
IndexSchema(
id=data_point.id,
text=DataPoint.get_embeddable_data(data_point),
)
for data_point in data_points
],
)
async def get_table(self, collection_name: str) -> Table:
"""
Dynamically loads a table using the given collection name
with an async engine.
"""
async with self.engine.begin() as connection:
# Create a MetaData instance to load table information
metadata = MetaData()
# Load table information from schema into MetaData
await connection.run_sync(metadata.reflect)
if collection_name in metadata.tables:
return metadata.tables[collection_name]
else:
raise EntityNotFoundError(message=f"Table '{collection_name}' not found.")
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
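        """Fetch data points by id from the collection, returned as ScoredResult objects with score 0."""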
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
async with self.get_async_session() as session:
results = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
results = results.all()
return [
ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
for result in results
]
async def get_distance_from_collection_elements(
self,
collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
with_vector: bool = False,
) -> List[ScoredResult]:
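        """Return every element of the collection with its cosine distance to the query.

        Either ``query_text`` (which is embedded first) or a precomputed ``query_vector``
        must be provided. Returns an empty list if the collection does not exist.
        """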
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
try:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# Use async session to connect to the database
async with self.get_async_session() as session:
# Find closest vectors to query_vector
closest_items = await session.execute(
select(
PGVectorDataPoint,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
"similarity"
),
).order_by("similarity")
)
vector_list = []
                # Collect the result rows; similarity scores are currently returned unnormalized
for vector in closest_items:
# TODO: Add normalization of similarity score
vector_list.append(vector)
# Create and return ScoredResult objects
return [
ScoredResult(id=parse_id(str(row.id)), payload=row.payload, score=row.similarity)
for row in vector_list
]
except EntityNotFoundError:
# Ignore if collection does not exist
return []
async def search(
self,
collection_name: str,
query_text: Optional[str] = None,
query_vector: Optional[List[float]] = None,
limit: int = 5,
with_vector: bool = False,
) -> List[ScoredResult]:
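        """Return the ``limit`` closest data points to the query, scored by normalized cosine distance.

        Either ``query_text`` (which is embedded first) or a precomputed ``query_vector``
        must be provided.
        """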
if query_text is None and query_vector is None:
raise InvalidValueError(message="One of query_text or query_vector must be provided!")
if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
closest_items = []
# Use async session to connect to the database
async with self.get_async_session() as session:
# Find closest vectors to query_vector
closest_items = await session.execute(
select(
PGVectorDataPoint,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
)
.order_by("similarity")
.limit(limit)
)
vector_list = []
# Extract distances and find min/max for normalization
for vector in closest_items:
vector_list.append(
{
"id": parse_id(str(vector.id)),
"payload": vector.payload,
"_distance": vector.similarity,
}
)
# Normalize vector distance and add this as score information to vector_list
normalized_values = normalize_distances(vector_list)
for i in range(0, len(normalized_values)):
vector_list[i]["score"] = normalized_values[i]
# Create and return ScoredResult objects
return [
ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
for row in vector_list
]
async def batch_search(
self,
collection_name: str,
query_texts: List[str],
        limit: Optional[int] = None,
with_vectors: bool = False,
):
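        """Embed all ``query_texts`` and run the individual searches concurrently."""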
query_vectors = await self.embedding_engine.embed_text(query_texts)
return await asyncio.gather(
*[
self.search(
collection_name=collection_name,
query_vector=query_vector,
limit=limit,
with_vector=with_vectors,
)
for query_vector in query_vectors
]
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
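        """Delete the rows whose id is in ``data_point_ids`` from the collection."""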
async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
results = await session.execute(
delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
await session.commit()
return results
async def prune(self):
# Clean up the database if it was set up as temporary
await self.delete_database()