refactor: Add formatting to PGVector Adapter

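Reformatted PGVectorAdapter (line wrapping, double-quoted strings, keyword-argument spacing, trailing commas, comment spacing); no functional changes.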

Refactor #COG-170
This commit is contained in:
Igor Ilic 2024-10-18 14:46:33 +02:00
parent 325e6cd654
commit 2cd255768e

View file

@ -15,6 +15,7 @@ from ...relational.ModelBase import Base
from datetime import datetime from datetime import datetime
# TODO: Find better location for function # TODO: Find better location for function
def serialize_datetime(data): def serialize_datetime(data):
"""Recursively convert datetime objects in dictionaries/lists to ISO format.""" """Recursively convert datetime objects in dictionaries/lists to ISO format."""
@ -27,11 +28,14 @@ def serialize_datetime(data):
else: else:
return data return data
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
def __init__(self, connection_string: str, def __init__(
self,
connection_string: str,
api_key: Optional[str], api_key: Optional[str],
embedding_engine: EmbeddingEngine embedding_engine: EmbeddingEngine,
): ):
self.api_key = api_key self.api_key = api_key
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
@ -45,9 +49,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def has_collection(self, collection_name: str) -> bool: async def has_collection(self, collection_name: str) -> bool:
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
#TODO: Switch to using ORM instead of raw query # TODO: Switch to using ORM instead of raw query
result = await connection.execute( result = await connection.execute(
text("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';") text(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
)
) )
tables = result.fetchall() tables = result.fetchall()
for table in tables: for table in tables:
@ -55,17 +61,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
return True return True
return False return False
async def create_collection(self, collection_name: str, payload_schema = None): async def create_collection(self, collection_name: str, payload_schema=None):
data_point_types = get_type_hints(DataPoint) data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size() vector_size = self.embedding_engine.get_vector_size()
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base): class PGVectorDataPoint(Base):
__tablename__ = collection_name __tablename__ = collection_name
__table_args__ = {'extend_existing': True} __table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key # PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[data_point_types["id"]] id: Mapped[data_point_types["id"]]
payload = Column(JSON) payload = Column(JSON)
vector = Column(Vector(vector_size)) vector = Column(Vector(vector_size))
@ -77,14 +85,18 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0: if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync(Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]) await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(
self, collection_name: str, data_points: List[DataPoint]
):
async with self.get_async_session() as session: async with self.get_async_session() as session:
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
await self.create_collection( await self.create_collection(
collection_name = collection_name, collection_name=collection_name,
payload_schema = type(data_points[0].payload), payload_schema=type(data_points[0].payload),
) )
data_vectors = await self.embed_data( data_vectors = await self.embed_data(
@ -95,9 +107,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
class PGVectorDataPoint(Base): class PGVectorDataPoint(Base):
__tablename__ = collection_name __tablename__ = collection_name
__table_args__ = {'extend_existing': True} __table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key # PGVector requires one column to be the primary key
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) primary_key: Mapped[int] = mapped_column(
primary_key=True, autoincrement=True
)
id: Mapped[type(data_points[0].id)] id: Mapped[type(data_points[0].id)]
payload = Column(JSON) payload = Column(JSON)
vector = Column(Vector(vector_size)) vector = Column(Vector(vector_size))
@ -109,10 +123,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
pgvector_data_points = [ pgvector_data_points = [
PGVectorDataPoint( PGVectorDataPoint(
id = data_point.id, id=data_point.id,
vector = data_vectors[data_index], vector=data_vectors[data_index],
payload = serialize_datetime(data_point.payload.dict()) payload=serialize_datetime(data_point.payload.dict()),
) for (data_index, data_point) in enumerate(data_points) )
for (data_index, data_point) in enumerate(data_points)
] ]
session.add_all(pgvector_data_points) session.add_all(pgvector_data_points)
@ -127,18 +142,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query = text(f"SELECT * FROM {collection_name} WHERE id = :id") query = text(f"SELECT * FROM {collection_name} WHERE id = :id")
result = await session.execute(query, {"id": data_point_ids[0]}) result = await session.execute(query, {"id": data_point_ids[0]})
else: else:
query = text(f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)") query = text(
f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)"
)
result = await session.execute(query, {"ids": data_point_ids}) result = await session.execute(query, {"ids": data_point_ids})
# Fetch all rows # Fetch all rows
rows = result.fetchall() rows = result.fetchall()
return [ return [
ScoredResult( ScoredResult(id=row["id"], payload=row["payload"], score=0)
id=row["id"],
payload=row["payload"],
score=0
)
for row in rows for row in rows
] ]
except Exception as e: except Exception as e:
@ -162,22 +175,31 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Use async session to connect to the database # Use async session to connect to the database
async with self.get_async_session() as session: async with self.get_async_session() as session:
try: try:
PGVectorDataPoint = Table(collection_name, Base.metadata, autoload_with=self.engine) PGVectorDataPoint = Table(
collection_name, Base.metadata, autoload_with=self.engine
)
closest_items = await session.execute(select(PGVectorDataPoint, PGVectorDataPoint.c.vector.cosine_distance(query_vector).label('similarity')).order_by('similarity').limit(limit)) closest_items = await session.execute(
select(
PGVectorDataPoint,
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
"similarity"
),
)
.order_by("similarity")
.limit(limit)
)
vector_list = [] vector_list = []
# Extract distances and find min/max for normalization # Extract distances and find min/max for normalization
for vector in closest_items: for vector in closest_items:
#TODO: Add normalization of similarity score # TODO: Add normalization of similarity score
vector_list.append(vector) vector_list.append(vector)
# Create and return ScoredResult objects # Create and return ScoredResult objects
return [ return [
ScoredResult( ScoredResult(
id=str(row.id), id=str(row.id), payload=row.payload, score=row.similarity
payload=row.payload,
score=row.similarity
) )
for row in vector_list for row in vector_list
] ]
@ -196,12 +218,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
query_vectors = await self.embedding_engine.embed_text(query_texts) query_vectors = await self.embedding_engine.embed_text(query_texts)
return asyncio.gather( return asyncio.gather(
*[self.search( *[
collection_name = collection_name, self.search(
query_vector = query_vector, collection_name=collection_name,
limit = limit, query_vector=query_vector,
with_vector = with_vectors, limit=limit,
) for query_vector in query_vectors] with_vector=with_vectors,
)
for query_vector in query_vectors
]
) )
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):