Merge pull request #360 from topoteretes/fix-pgvector-search
Fix pgvector search
This commit is contained in:
commit
ec38404e95
2 changed files with 9 additions and 23 deletions
|
|
@ -24,19 +24,6 @@ class IndexSchema(DataPoint):
|
|||
"index_fields": ["text"]
|
||||
}
|
||||
|
||||
def singleton(class_):
    """Class decorator that turns ``class_`` into a singleton.

    The first call constructs the instance with the arguments given; every
    later call returns that same instance and ignores its arguments.

    Note: the decorated name is bound to a factory function rather than the
    class itself, so class methods are not reachable through it directly.
    The undecorated class remains available as ``<decorated name>.__wrapped__``.

    Args:
        class_: The class to instantiate at most once.

    Returns:
        A callable that returns the single shared instance of ``class_``.
    """
    # Local import: the file's top-level import block is not visible from here.
    import functools

    instances = {}

    # updated=() stops functools from copying the class namespace into the
    # wrapper function's __dict__; only __name__, __doc__, __module__, etc.
    # are carried over, and __wrapped__ is set to the original class.
    @functools.wraps(class_, updated=())
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]

    return getinstance
|
||||
|
||||
@singleton
|
||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||
|
||||
def __init__(
|
||||
|
|
@ -51,6 +38,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
self.engine = create_async_engine(self.db_uri)
|
||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||
|
||||
# Has to be imported at class level
|
||||
# Functions reading tables from database need to know what a Vector column type is
|
||||
from pgvector.sqlalchemy import Vector
|
||||
self.Vector = Vector
|
||||
|
||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
|
|
@ -70,7 +62,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
if not await self.has_collection(collection_name):
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
class PGVectorDataPoint(Base):
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
|
@ -80,7 +71,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
)
|
||||
id: Mapped[data_point_types["id"]]
|
||||
payload = Column(JSON)
|
||||
vector = Column(Vector(vector_size))
|
||||
vector = Column(self.Vector(vector_size))
|
||||
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
|
|
@ -108,7 +99,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
vector_size = self.embedding_engine.get_vector_size()
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
class PGVectorDataPoint(Base):
|
||||
__tablename__ = collection_name
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
|
@ -118,7 +108,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
)
|
||||
id: Mapped[type(data_points[0].id)]
|
||||
payload = Column(JSON)
|
||||
vector = Column(Vector(vector_size))
|
||||
vector = Column(self.Vector(vector_size))
|
||||
|
||||
def __init__(self, id, payload, vector):
|
||||
self.id = id
|
||||
|
|
|
|||
8
poetry.lock
generated
8
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
|
|
@ -2898,8 +2898,6 @@ optional = false
|
|||
python-versions = "*"
|
||||
files = [
|
||||
{file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"},
|
||||
{file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"},
|
||||
{file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -8879,6 +8877,4 @@ weaviate = ["weaviate-client"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9.0,<3.12"
|
||||
|
||||
|
||||
content-hash = "11a43b99fb231db46cb07d72cb19b6ffde1a263862122c3f53e759b618ce18b7"
|
||||
content-hash = "af91e3dcf6a8927ed938fe3f78172a5f1e0c0f9c8fbcbc76767b0e0d84645c9e"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue