refactor: Add formatting to PGVector Adapter
Formatted PGVectorAdapter Refactor #COG-170
This commit is contained in:
parent
325e6cd654
commit
2cd255768e
1 changed files with 61 additions and 36 deletions
|
|
@ -15,6 +15,7 @@ from ...relational.ModelBase import Base
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
# TODO: Find better location for function
|
# TODO: Find better location for function
|
||||||
def serialize_datetime(data):
|
def serialize_datetime(data):
|
||||||
"""Recursively convert datetime objects in dictionaries/lists to ISO format."""
|
"""Recursively convert datetime objects in dictionaries/lists to ISO format."""
|
||||||
|
|
@ -27,11 +28,14 @@ def serialize_datetime(data):
|
||||||
else:
|
else:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
def __init__(self, connection_string: str,
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection_string: str,
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
embedding_engine: EmbeddingEngine
|
embedding_engine: EmbeddingEngine,
|
||||||
):
|
):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
|
|
@ -45,9 +49,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
async def has_collection(self, collection_name: str) -> bool:
|
async def has_collection(self, collection_name: str) -> bool:
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
#TODO: Switch to using ORM instead of raw query
|
# TODO: Switch to using ORM instead of raw query
|
||||||
result = await connection.execute(
|
result = await connection.execute(
|
||||||
text("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';")
|
text(
|
||||||
|
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
tables = result.fetchall()
|
tables = result.fetchall()
|
||||||
for table in tables:
|
for table in tables:
|
||||||
|
|
@ -55,17 +61,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema = None):
|
async def create_collection(self, collection_name: str, payload_schema=None):
|
||||||
data_point_types = get_type_hints(DataPoint)
|
data_point_types = get_type_hints(DataPoint)
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
|
|
||||||
class PGVectorDataPoint(Base):
|
class PGVectorDataPoint(Base):
|
||||||
__tablename__ = collection_name
|
__tablename__ = collection_name
|
||||||
__table_args__ = {'extend_existing': True}
|
__table_args__ = {"extend_existing": True}
|
||||||
# PGVector requires one column to be the primary key
|
# PGVector requires one column to be the primary key
|
||||||
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
primary_key: Mapped[int] = mapped_column(
|
||||||
|
primary_key=True, autoincrement=True
|
||||||
|
)
|
||||||
id: Mapped[data_point_types["id"]]
|
id: Mapped[data_point_types["id"]]
|
||||||
payload = Column(JSON)
|
payload = Column(JSON)
|
||||||
vector = Column(Vector(vector_size))
|
vector = Column(Vector(vector_size))
|
||||||
|
|
@ -77,14 +85,18 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
if len(Base.metadata.tables.keys()) > 0:
|
if len(Base.metadata.tables.keys()) > 0:
|
||||||
await connection.run_sync(Base.metadata.create_all, tables=[PGVectorDataPoint.__table__])
|
await connection.run_sync(
|
||||||
|
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
||||||
|
)
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
async def create_data_points(
|
||||||
|
self, collection_name: str, data_points: List[DataPoint]
|
||||||
|
):
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
await self.create_collection(
|
await self.create_collection(
|
||||||
collection_name = collection_name,
|
collection_name=collection_name,
|
||||||
payload_schema = type(data_points[0].payload),
|
payload_schema=type(data_points[0].payload),
|
||||||
)
|
)
|
||||||
|
|
||||||
data_vectors = await self.embed_data(
|
data_vectors = await self.embed_data(
|
||||||
|
|
@ -95,9 +107,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
class PGVectorDataPoint(Base):
|
class PGVectorDataPoint(Base):
|
||||||
__tablename__ = collection_name
|
__tablename__ = collection_name
|
||||||
__table_args__ = {'extend_existing': True}
|
__table_args__ = {"extend_existing": True}
|
||||||
# PGVector requires one column to be the primary key
|
# PGVector requires one column to be the primary key
|
||||||
primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
primary_key: Mapped[int] = mapped_column(
|
||||||
|
primary_key=True, autoincrement=True
|
||||||
|
)
|
||||||
id: Mapped[type(data_points[0].id)]
|
id: Mapped[type(data_points[0].id)]
|
||||||
payload = Column(JSON)
|
payload = Column(JSON)
|
||||||
vector = Column(Vector(vector_size))
|
vector = Column(Vector(vector_size))
|
||||||
|
|
@ -109,10 +123,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
pgvector_data_points = [
|
pgvector_data_points = [
|
||||||
PGVectorDataPoint(
|
PGVectorDataPoint(
|
||||||
id = data_point.id,
|
id=data_point.id,
|
||||||
vector = data_vectors[data_index],
|
vector=data_vectors[data_index],
|
||||||
payload = serialize_datetime(data_point.payload.dict())
|
payload=serialize_datetime(data_point.payload.dict()),
|
||||||
) for (data_index, data_point) in enumerate(data_points)
|
)
|
||||||
|
for (data_index, data_point) in enumerate(data_points)
|
||||||
]
|
]
|
||||||
|
|
||||||
session.add_all(pgvector_data_points)
|
session.add_all(pgvector_data_points)
|
||||||
|
|
@ -127,18 +142,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
query = text(f"SELECT * FROM {collection_name} WHERE id = :id")
|
query = text(f"SELECT * FROM {collection_name} WHERE id = :id")
|
||||||
result = await session.execute(query, {"id": data_point_ids[0]})
|
result = await session.execute(query, {"id": data_point_ids[0]})
|
||||||
else:
|
else:
|
||||||
query = text(f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)")
|
query = text(
|
||||||
|
f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)"
|
||||||
|
)
|
||||||
result = await session.execute(query, {"ids": data_point_ids})
|
result = await session.execute(query, {"ids": data_point_ids})
|
||||||
|
|
||||||
# Fetch all rows
|
# Fetch all rows
|
||||||
rows = result.fetchall()
|
rows = result.fetchall()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(id=row["id"], payload=row["payload"], score=0)
|
||||||
id=row["id"],
|
|
||||||
payload=row["payload"],
|
|
||||||
score=0
|
|
||||||
)
|
|
||||||
for row in rows
|
for row in rows
|
||||||
]
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -162,22 +175,31 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
# Use async session to connect to the database
|
# Use async session to connect to the database
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
try:
|
try:
|
||||||
PGVectorDataPoint = Table(collection_name, Base.metadata, autoload_with=self.engine)
|
PGVectorDataPoint = Table(
|
||||||
|
collection_name, Base.metadata, autoload_with=self.engine
|
||||||
|
)
|
||||||
|
|
||||||
closest_items = await session.execute(select(PGVectorDataPoint, PGVectorDataPoint.c.vector.cosine_distance(query_vector).label('similarity')).order_by('similarity').limit(limit))
|
closest_items = await session.execute(
|
||||||
|
select(
|
||||||
|
PGVectorDataPoint,
|
||||||
|
PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
|
||||||
|
"similarity"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.order_by("similarity")
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
vector_list = []
|
vector_list = []
|
||||||
# Extract distances and find min/max for normalization
|
# Extract distances and find min/max for normalization
|
||||||
for vector in closest_items:
|
for vector in closest_items:
|
||||||
#TODO: Add normalization of similarity score
|
# TODO: Add normalization of similarity score
|
||||||
vector_list.append(vector)
|
vector_list.append(vector)
|
||||||
|
|
||||||
# Create and return ScoredResult objects
|
# Create and return ScoredResult objects
|
||||||
return [
|
return [
|
||||||
ScoredResult(
|
ScoredResult(
|
||||||
id=str(row.id),
|
id=str(row.id), payload=row.payload, score=row.similarity
|
||||||
payload=row.payload,
|
|
||||||
score=row.similarity
|
|
||||||
)
|
)
|
||||||
for row in vector_list
|
for row in vector_list
|
||||||
]
|
]
|
||||||
|
|
@ -196,12 +218,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
query_vectors = await self.embedding_engine.embed_text(query_texts)
|
||||||
|
|
||||||
return asyncio.gather(
|
return asyncio.gather(
|
||||||
*[self.search(
|
*[
|
||||||
collection_name = collection_name,
|
self.search(
|
||||||
query_vector = query_vector,
|
collection_name=collection_name,
|
||||||
limit = limit,
|
query_vector=query_vector,
|
||||||
with_vector = with_vectors,
|
limit=limit,
|
||||||
) for query_vector in query_vectors]
|
with_vector=with_vectors,
|
||||||
|
)
|
||||||
|
for query_vector in query_vectors
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue