refactor: Add formatting to PGVector Adapter

Formatted PGVectorAdapter

Refactor #COG-170
Igor Ilic 2024-10-18 14:46:33 +02:00
parent 325e6cd654
commit 2cd255768e


@@ -15,6 +15,7 @@ from ...relational.ModelBase import Base
 from datetime import datetime
 
+
 # TODO: Find better location for function
 def serialize_datetime(data):
     """Recursively convert datetime objects in dictionaries/lists to ISO format."""
@@ -27,11 +28,14 @@ def serialize_datetime(data):
     else:
         return data
 
+
 class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
-    def __init__(self, connection_string: str,
+    def __init__(
+        self,
+        connection_string: str,
         api_key: Optional[str],
-        embedding_engine: EmbeddingEngine
+        embedding_engine: EmbeddingEngine,
     ):
         self.api_key = api_key
         self.embedding_engine = embedding_engine
@@ -45,9 +49,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
     async def has_collection(self, collection_name: str) -> bool:
         async with self.engine.begin() as connection:
-            #TODO: Switch to using ORM instead of raw query
+            # TODO: Switch to using ORM instead of raw query
             result = await connection.execute(
-                text("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';")
+                text(
+                    "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
+                )
             )
             tables = result.fetchall()
             for table in tables:
@@ -55,17 +61,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                 return True
         return False
 
-    async def create_collection(self, collection_name: str, payload_schema = None):
+    async def create_collection(self, collection_name: str, payload_schema=None):
         data_point_types = get_type_hints(DataPoint)
         vector_size = self.embedding_engine.get_vector_size()
 
         if not await self.has_collection(collection_name):
 
             class PGVectorDataPoint(Base):
                 __tablename__ = collection_name
-                __table_args__ = {'extend_existing': True}
+                __table_args__ = {"extend_existing": True}
                 # PGVector requires one column to be the primary key
-                primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+                primary_key: Mapped[int] = mapped_column(
+                    primary_key=True, autoincrement=True
+                )
                 id: Mapped[data_point_types["id"]]
                 payload = Column(JSON)
                 vector = Column(Vector(vector_size))
@@ -77,14 +85,18 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
             async with self.engine.begin() as connection:
                 if len(Base.metadata.tables.keys()) > 0:
-                    await connection.run_sync(Base.metadata.create_all, tables=[PGVectorDataPoint.__table__])
+                    await connection.run_sync(
+                        Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
+                    )
 
-    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
+    async def create_data_points(
+        self, collection_name: str, data_points: List[DataPoint]
+    ):
         async with self.get_async_session() as session:
             if not await self.has_collection(collection_name):
                 await self.create_collection(
-                    collection_name = collection_name,
-                    payload_schema = type(data_points[0].payload),
+                    collection_name=collection_name,
+                    payload_schema=type(data_points[0].payload),
                 )
 
             data_vectors = await self.embed_data(
@@ -95,9 +107,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
            class PGVectorDataPoint(Base):
                 __tablename__ = collection_name
-                __table_args__ = {'extend_existing': True}
+                __table_args__ = {"extend_existing": True}
                 # PGVector requires one column to be the primary key
-                primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+                primary_key: Mapped[int] = mapped_column(
+                    primary_key=True, autoincrement=True
+                )
                 id: Mapped[type(data_points[0].id)]
                 payload = Column(JSON)
                 vector = Column(Vector(vector_size))
@@ -109,10 +123,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
             pgvector_data_points = [
                 PGVectorDataPoint(
-                    id = data_point.id,
-                    vector = data_vectors[data_index],
-                    payload = serialize_datetime(data_point.payload.dict())
-                ) for (data_index, data_point) in enumerate(data_points)
+                    id=data_point.id,
+                    vector=data_vectors[data_index],
+                    payload=serialize_datetime(data_point.payload.dict()),
+                )
+                for (data_index, data_point) in enumerate(data_points)
             ]
 
             session.add_all(pgvector_data_points)
@@ -127,18 +142,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                 query = text(f"SELECT * FROM {collection_name} WHERE id = :id")
                 result = await session.execute(query, {"id": data_point_ids[0]})
             else:
-                query = text(f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)")
+                query = text(
+                    f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)"
+                )
                 result = await session.execute(query, {"ids": data_point_ids})
 
             # Fetch all rows
             rows = result.fetchall()
 
             return [
-                ScoredResult(
-                    id=row["id"],
-                    payload=row["payload"],
-                    score=0
-                )
+                ScoredResult(id=row["id"], payload=row["payload"], score=0)
                 for row in rows
             ]
         except Exception as e:
@@ -162,22 +175,31 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         # Use async session to connect to the database
         async with self.get_async_session() as session:
             try:
-                PGVectorDataPoint = Table(collection_name, Base.metadata, autoload_with=self.engine)
+                PGVectorDataPoint = Table(
+                    collection_name, Base.metadata, autoload_with=self.engine
+                )
 
-                closest_items = await session.execute(select(PGVectorDataPoint, PGVectorDataPoint.c.vector.cosine_distance(query_vector).label('similarity')).order_by('similarity').limit(limit))
+                closest_items = await session.execute(
+                    select(
+                        PGVectorDataPoint,
+                        PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
+                            "similarity"
+                        ),
+                    )
+                    .order_by("similarity")
+                    .limit(limit)
+                )
 
                 vector_list = []
                 # Extract distances and find min/max for normalization
                 for vector in closest_items:
-                    #TODO: Add normalization of similarity score
+                    # TODO: Add normalization of similarity score
                     vector_list.append(vector)
 
                 # Create and return ScoredResult objects
                 return [
                     ScoredResult(
-                        id=str(row.id),
-                        payload=row.payload,
-                        score=row.similarity
+                        id=str(row.id), payload=row.payload, score=row.similarity
                     )
                     for row in vector_list
                 ]
@@ -196,12 +218,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         query_vectors = await self.embedding_engine.embed_text(query_texts)
 
         return asyncio.gather(
-            *[self.search(
-                collection_name = collection_name,
-                query_vector = query_vector,
-                limit = limit,
-                with_vector = with_vectors,
-            ) for query_vector in query_vectors]
+            *[
+                self.search(
+                    collection_name=collection_name,
+                    query_vector=query_vector,
+                    limit=limit,
+                    with_vector=with_vectors,
+                )
+                for query_vector in query_vectors
+            ]
         )
 
     async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
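
Note: only the tail of serialize_datetime (the final else: return data) is visible in the hunks above. For context, a minimal sketch of what the full helper looks like, consistent with its docstring and the visible tail — the dict, list, and datetime branches are reconstructed assumptions, not lines from this commit:

from datetime import datetime


def serialize_datetime(data):
    """Recursively convert datetime objects in dictionaries/lists to ISO format."""
    if isinstance(data, dict):
        # Assumed branch: recurse into each dictionary value.
        return {key: serialize_datetime(value) for key, value in data.items()}
    elif isinstance(data, list):
        # Assumed branch: recurse into each list element.
        return [serialize_datetime(item) for item in data]
    elif isinstance(data, datetime):
        # Leaf case: render the timestamp as a JSON-safe ISO 8601 string.
        return data.isoformat()
    else:
        return data


# Example: nested datetimes are converted, everything else passes through.
payload = {"created_at": datetime(2024, 10, 18, 14, 46, 33), "tags": ["a", "b"]}
print(serialize_datetime(payload))
# {'created_at': '2024-10-18T14:46:33', 'tags': ['a', 'b']}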
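
The search hunk keeps a TODO about normalizing the similarity score, and the comment above the loop already mentions extracting min/max distances. One way that TODO might eventually be resolved — a sketch only, with hypothetical names, not code from this commit — is a min-max rescale of the cosine distances returned by the query:

def normalize_distances(rows):
    """Hypothetical helper: min-max scale cosine distances into [0, 1]."""
    distances = [row.similarity for row in rows]
    if not distances:
        return []
    lo, hi = min(distances), max(distances)
    if hi == lo:
        # All hits are equally distant; there is no spread to normalize over.
        return [0.0 for _ in distances]
    return [(d - lo) / (hi - lo) for d in distances]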