fix: Resolve issue regrading not having Vector column type defined when using vector search
Issue happens when search is called in a session without previously adding data or creating tables as an import of Vector column type was missing Fix
This commit is contained in:
parent
92ecd8a024
commit
599e1d478b
1 changed files with 7 additions and 17 deletions
|
|
@ -24,19 +24,6 @@ class IndexSchema(DataPoint):
|
|||
"index_fields": ["text"]
|
||||
}
|
||||
|
||||
def singleton(class_):
|
||||
# Note: Using this singleton as a decorator to a class removes
|
||||
# the option to use class methods for that class
|
||||
instances = {}
|
||||
|
||||
def getinstance(*args, **kwargs):
|
||||
if class_ not in instances:
|
||||
instances[class_] = class_(*args, **kwargs)
|
||||
return instances[class_]
|
||||
|
||||
return getinstance
|
||||
|
||||
@singleton
|
||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||
|
||||
def __init__(
|
||||
|
|
@ -51,6 +38,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
self.engine = create_async_engine(self.db_uri)
|
||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||
|
||||
# Has to be imported at class level
|
||||
# Functions reading tables from database need to know what a Vector column type is
|
||||
from pgvector.sqlalchemy import Vector
|
||||
self.Vector = Vector
|
||||
|
||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
|
|
@ -70,7 +62,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
if not await self.has_collection(collection_name):
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
class PGVectorDataPoint(Base):
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
|
@ -80,7 +71,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
)
|
||||
id: Mapped[data_point_types["id"]]
|
||||
payload = Column(JSON)
|
||||
vector = Column(Vector(vector_size))
|
||||
vector = Column(self.Vector(vector_size))
|
||||
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
|
|
@ -108,7 +99,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
vector_size = self.embedding_engine.get_vector_size()
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
class PGVectorDataPoint(Base):
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
|
@ -118,7 +108,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
)
|
||||
id: Mapped[type(data_points[0].id)]
|
||||
payload = Column(JSON)
|
||||
vector = Column(Vector(vector_size))
|
||||
vector = Column(self.Vector(vector_size))
|
||||
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue