diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index cd275506a..8faf1cd6d 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -24,19 +24,6 @@ class IndexSchema(DataPoint): "index_fields": ["text"] } -def singleton(class_): - # Note: Using this singleton as a decorator to a class removes - # the option to use class methods for that class - instances = {} - - def getinstance(*args, **kwargs): - if class_ not in instances: - instances[class_] = class_(*args, **kwargs) - return instances[class_] - - return getinstance - -@singleton class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): def __init__( @@ -51,6 +38,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): self.engine = create_async_engine(self.db_uri) self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) + # Has to be imported at class level + # Functions reading tables from database need to know what a Vector column type is + from pgvector.sqlalchemy import Vector + self.Vector = Vector + async def embed_data(self, data: list[str]) -> list[list[float]]: return await self.embedding_engine.embed_text(data) @@ -70,7 +62,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): if not await self.has_collection(collection_name): - from pgvector.sqlalchemy import Vector class PGVectorDataPoint(Base): __tablename__ = collection_name __table_args__ = {"extend_existing": True} @@ -80,7 +71,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): ) id: Mapped[data_point_types["id"]] payload = Column(JSON) - vector = Column(Vector(vector_size)) + vector = Column(self.Vector(vector_size)) def __init__(self, id, payload, vector): self.id = id @@ -108,7 +99,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): vector_size = self.embedding_engine.get_vector_size() - from pgvector.sqlalchemy import Vector class PGVectorDataPoint(Base): __tablename__ = collection_name __table_args__ = {"extend_existing": True} @@ -118,7 +108,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): ) id: Mapped[type(data_points[0].id)] payload = Column(JSON) - vector = Column(Vector(vector_size)) + vector = Column(self.Vector(vector_size)) def __init__(self, id, payload, vector): self.id = id