<!-- .github/pull_request_template.md -->

## Description

Resolve issues with duplicate chunks for PGVector.

## DCO Affirmation

I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin.
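
The change makes `create_data_points` behave as an upsert keyed on `id`: re-running ingestion over the same chunks updates the existing rows instead of inserting duplicates. A minimal sketch of the intended behavior (the module path, connection string, and embedding engine below are placeholders, not part of this PR):

```python
import asyncio
from uuid import NAMESPACE_OID, uuid5

# Assumed module path, based on the repository layout.
from cognee.infrastructure.databases.vector.pgvector.PGVectorAdapter import (
    IndexSchema,
    PGVectorAdapter,
)


async def main():
    adapter = PGVectorAdapter(
        connection_string="postgresql+asyncpg://cognee:cognee@localhost:5432/cognee",
        api_key=None,
        embedding_engine=embedding_engine,  # placeholder: a configured EmbeddingEngine
    )

    # Deterministic ids: identical text maps to the same UUID, so a second
    # ingestion run hits the update branch instead of inserting a new row.
    chunks = [
        IndexSchema(id=uuid5(NAMESPACE_OID, text), text=text)
        for text in ["first chunk", "second chunk"]
    ]

    await adapter.create_data_points("document_chunks", chunks)
    await adapter.create_data_points("document_chunks", chunks)  # updates, no duplicates


asyncio.run(main())
```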

335 lines · 13 KiB · Python

import asyncio

from typing import List, Optional, get_type_hints
from uuid import UUID, uuid4

from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy import JSON, Column, Table, select, delete, MetaData
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker

from cognee.exceptions import InvalidValueError
from cognee.infrastructure.databases.exceptions import EntityNotFoundError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.engine.utils import parse_id
from cognee.infrastructure.databases.relational import get_relational_engine

from ...relational.ModelBase import Base
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
from .serialize_data import serialize_data
from ..utils import normalize_distances


class IndexSchema(DataPoint):
    text: str

    metadata: dict = {"index_fields": ["text"]}


class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
    def __init__(
        self,
        connection_string: str,
        api_key: Optional[str],
        embedding_engine: EmbeddingEngine,
    ):
        self.api_key = api_key
        self.embedding_engine = embedding_engine
        self.db_uri: str = connection_string

        relational_db = get_relational_engine()

        # If PostgreSQL is used, we must reuse the relational database's
        # engine and sessionmaker
        if relational_db.engine.dialect.name == "postgresql":
            self.engine = relational_db.engine
            self.sessionmaker = relational_db.sessionmaker
        else:
            # If not, create new engine and sessionmaker instances
            self.engine = create_async_engine(self.db_uri)
            self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)

        # Imported here rather than at module level: functions reading tables
        # from the database need to know what a Vector column type is.
        from pgvector.sqlalchemy import Vector

        self.Vector = Vector

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        async with self.engine.begin() as connection:
            # Create a MetaData instance and reflect table information
            # from the schema into it
            metadata = MetaData()
            await connection.run_sync(metadata.reflect)

            return collection_name in metadata.tables

    async def create_collection(self, collection_name: str, payload_schema=None):
        data_point_types = get_type_hints(DataPoint)
        vector_size = self.embedding_engine.get_vector_size()

        if not await self.has_collection(collection_name):

            class PGVectorDataPoint(Base):
                __tablename__ = collection_name
                __table_args__ = {"extend_existing": True}
                # PGVector requires one column to be the primary key
                primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
                id: Mapped[data_point_types["id"]]
                payload = Column(JSON)
                vector = Column(self.Vector(vector_size))

                def __init__(self, id, payload, vector):
                    self.id = id
                    self.payload = payload
                    self.vector = vector

            async with self.engine.begin() as connection:
                if len(Base.metadata.tables.keys()) > 0:
                    await connection.run_sync(
                        Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
                    )

    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        data_point_types = get_type_hints(DataPoint)
        if not await self.has_collection(collection_name):
            await self.create_collection(
                collection_name=collection_name,
                payload_schema=type(data_points[0]),
            )

        data_vectors = await self.embed_data(
            [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
        )

        vector_size = self.embedding_engine.get_vector_size()

        class PGVectorDataPoint(Base):
            __tablename__ = collection_name
            __table_args__ = {"extend_existing": True}
            # PGVector requires one column to be the primary key
            primary_key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
            id: Mapped[data_point_types["id"]]
            payload = Column(JSON)
            vector = Column(self.Vector(vector_size))

            def __init__(self, id, payload, vector):
                self.id = id
                self.payload = payload
                self.vector = vector

        async with self.get_async_session() as session:
            pgvector_data_points = []

            for data_index, data_point in enumerate(data_points):
                # Check whether this data point should be updated
                # or a new data item should be created
                data_point_db = (
                    await session.execute(
                        select(PGVectorDataPoint).filter(PGVectorDataPoint.id == data_point.id)
                    )
                ).scalar_one_or_none()

                # If the data point exists, update it; if not, create a new one
                if data_point_db:
                    data_point_db.id = data_point.id
                    data_point_db.vector = data_vectors[data_index]
                    data_point_db.payload = serialize_data(data_point.model_dump())
                    pgvector_data_points.append(data_point_db)
                else:
                    pgvector_data_points.append(
                        PGVectorDataPoint(
                            id=data_point.id,
                            vector=data_vectors[data_index],
                            payload=serialize_data(data_point.model_dump()),
                        )
                    )

            session.add_all(pgvector_data_points)
            await session.commit()
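
    # NOTE: the per-id lookup above is the duplicate-chunk fix: a data point
    # whose id already exists in the collection is updated in place instead
    # of being inserted a second time.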

    async def create_vector_index(self, index_name: str, index_property_name: str):
        await self.create_collection(f"{index_name}_{index_property_name}")

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
    ):
        await self.create_data_points(
            f"{index_name}_{index_property_name}",
            [
                IndexSchema(
                    id=data_point.id,
                    text=DataPoint.get_embeddable_data(data_point),
                )
                for data_point in data_points
            ],
        )

    async def get_table(self, collection_name: str) -> Table:
        """
        Dynamically load the table for the given collection name
        using the async engine.
        """
        async with self.engine.begin() as connection:
            # Create a MetaData instance and reflect table information
            # from the schema into it
            metadata = MetaData()
            await connection.run_sync(metadata.reflect)

            if collection_name in metadata.tables:
                return metadata.tables[collection_name]
            else:
                raise EntityNotFoundError(message=f"Table '{collection_name}' not found.")

    async def retrieve(self, collection_name: str, data_point_ids: List[str]):
        # Get the PGVectorDataPoint table from the database
        PGVectorDataPoint = await self.get_table(collection_name)

        async with self.get_async_session() as session:
            results = await session.execute(
                select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
            )
            results = results.all()

            return [
                ScoredResult(id=parse_id(result.id), payload=result.payload, score=0)
                for result in results
            ]

    async def get_distance_from_collection_elements(
        self,
        collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        with_vector: bool = False,
    ) -> List[ScoredResult]:
        if query_text is None and query_vector is None:
            raise InvalidValueError(message="One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        try:
            # Get the PGVectorDataPoint table from the database
            PGVectorDataPoint = await self.get_table(collection_name)

            # Use an async session to connect to the database
            async with self.get_async_session() as session:
                # Find the closest vectors to query_vector
                closest_items = await session.execute(
                    select(
                        PGVectorDataPoint,
                        PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
                            "similarity"
                        ),
                    ).order_by("similarity")
                )

                vector_list = []

                # Collect the result rows
                # TODO: Add normalization of similarity score
                for vector in closest_items:
                    vector_list.append(vector)

                # Create and return ScoredResult objects
                return [
                    ScoredResult(id=parse_id(str(row.id)), payload=row.payload, score=row.similarity)
                    for row in vector_list
                ]
        except EntityNotFoundError:
            # Ignore if the collection does not exist
            return []

    async def search(
        self,
        collection_name: str,
        query_text: Optional[str] = None,
        query_vector: Optional[List[float]] = None,
        limit: int = 5,
        with_vector: bool = False,
    ) -> List[ScoredResult]:
        if query_text is None and query_vector is None:
            raise InvalidValueError(message="One of query_text or query_vector must be provided!")

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        # Get the PGVectorDataPoint table from the database
        PGVectorDataPoint = await self.get_table(collection_name)

        # Use an async session to connect to the database
        async with self.get_async_session() as session:
            # Find the closest vectors to query_vector
            closest_items = await session.execute(
                select(
                    PGVectorDataPoint,
                    PGVectorDataPoint.c.vector.cosine_distance(query_vector).label("similarity"),
                )
                .order_by("similarity")
                .limit(limit)
            )

            vector_list = []

            # Extract the distances for normalization
            for vector in closest_items:
                vector_list.append(
                    {
                        "id": parse_id(str(vector.id)),
                        "payload": vector.payload,
                        "_distance": vector.similarity,
                    }
                )

            # Normalize the vector distances and attach them to vector_list as scores
            normalized_values = normalize_distances(vector_list)
            for index, normalized_value in enumerate(normalized_values):
                vector_list[index]["score"] = normalized_value

            # Create and return ScoredResult objects
            return [
                ScoredResult(id=row.get("id"), payload=row.get("payload"), score=row.get("score"))
                for row in vector_list
            ]

    async def batch_search(
        self,
        collection_name: str,
        query_texts: List[str],
        limit: Optional[int] = None,
        with_vectors: bool = False,
    ):
        query_vectors = await self.embedding_engine.embed_text(query_texts)

        return await asyncio.gather(
            *[
                self.search(
                    collection_name=collection_name,
                    query_vector=query_vector,
                    limit=limit,
                    with_vector=with_vectors,
                )
                for query_vector in query_vectors
            ]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        async with self.get_async_session() as session:
            # Get the PGVectorDataPoint table from the database
            PGVectorDataPoint = await self.get_table(collection_name)
            results = await session.execute(
                delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
            )
            await session.commit()
            return results

    async def prune(self):
        # Clean up the database if it was set up as temporary
        await self.delete_database()
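
For reference, a hedged sketch of the read path against a collection populated as in the earlier sketch (the collection name and queries are illustrative, and `adapter` is the instance constructed above):

```python
async def query_example(adapter: PGVectorAdapter):
    # search() embeds the query text, orders rows by cosine distance,
    # and returns ScoredResult objects with normalized scores.
    results = await adapter.search(
        collection_name="document_chunks",
        query_text="what is in the first chunk?",
        limit=3,
    )
    for result in results:
        print(result.id, result.score)

    # batch_search() embeds all queries in one call, then runs the
    # per-query searches concurrently with asyncio.gather.
    batches = await adapter.batch_search(
        collection_name="document_chunks",
        query_texts=["first chunk", "second chunk"],
        limit=3,
    )
    return batches
```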