fix: Resolve issue regrading not having Vector column type defined when using vector search

Issue happens when search is called in a session without previously adding data or creating tables as an import of Vector column type was missing

Fix
This commit is contained in:
Igor Ilic 2024-12-12 13:37:18 +01:00
parent 92ecd8a024
commit 599e1d478b

View file

@ -24,19 +24,6 @@ class IndexSchema(DataPoint):
"index_fields": ["text"]
}
def singleton(class_):
# Note: Using this singleton as a decorator to a class removes
# the option to use class methods for that class
instances = {}
def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]
return getinstance
@singleton
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
def __init__(
@ -51,6 +38,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self.engine = create_async_engine(self.db_uri)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
# Has to be imported at class level
# Functions reading tables from database need to know what a Vector column type is
from pgvector.sqlalchemy import Vector
self.Vector = Vector
async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data)
@ -70,7 +62,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
if not await self.has_collection(collection_name):
from pgvector.sqlalchemy import Vector
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
@ -80,7 +71,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
)
id: Mapped[data_point_types["id"]]
payload = Column(JSON)
vector = Column(Vector(vector_size))
vector = Column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id
@ -108,7 +99,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
vector_size = self.embedding_engine.get_vector_size()
from pgvector.sqlalchemy import Vector
class PGVectorDataPoint(Base):
__tablename__ = collection_name
__table_args__ = {"extend_existing": True}
@ -118,7 +108,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
)
id: Mapped[type(data_points[0].id)]
payload = Column(JSON)
vector = Column(Vector(vector_size))
vector = Column(self.Vector(vector_size))
def __init__(self, id, payload, vector):
self.id = id