diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
index c592a5b7a..9b8ef1c88 100644
--- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -15,6 +15,7 @@ from ...relational.ModelBase import Base
 from datetime import datetime
 
+
 # TODO: Find better location for function
 def serialize_datetime(data):
     """Recursively convert datetime objects in dictionaries/lists to ISO format."""
@@ -27,11 +28,14 @@ def serialize_datetime(data):
     else:
         return data
 
+
 class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
-    def __init__(self, connection_string: str,
+    def __init__(
+        self,
+        connection_string: str,
         api_key: Optional[str],
-        embedding_engine: EmbeddingEngine
+        embedding_engine: EmbeddingEngine,
     ):
         self.api_key = api_key
         self.embedding_engine = embedding_engine
@@ -45,9 +49,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
     async def has_collection(self, collection_name: str) -> bool:
         async with self.engine.begin() as connection:
-            #TODO: Switch to using ORM instead of raw query
+            # TODO: Switch to using ORM instead of raw query
             result = await connection.execute(
-                text("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';")
+                text(
+                    "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
+                )
             )
             tables = result.fetchall()
             for table in tables:
@@ -55,17 +61,19 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                     return True
             return False
 
-    async def create_collection(self, collection_name: str, payload_schema = None):
+    async def create_collection(self, collection_name: str, payload_schema=None):
         data_point_types = get_type_hints(DataPoint)
         vector_size = self.embedding_engine.get_vector_size()
 
         if not await self.has_collection(collection_name):
-
+
             class PGVectorDataPoint(Base):
                 __tablename__ = collection_name
-                __table_args__ = {'extend_existing': True}
+                __table_args__ = {"extend_existing": True}
                 # PGVector requires one column to be the primary key
-                primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+                primary_key: Mapped[int] = mapped_column(
+                    primary_key=True, autoincrement=True
+                )
                 id: Mapped[data_point_types["id"]]
                 payload = Column(JSON)
                 vector = Column(Vector(vector_size))
@@ -77,14 +85,18 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
             async with self.engine.begin() as connection:
                 if len(Base.metadata.tables.keys()) > 0:
-                    await connection.run_sync(Base.metadata.create_all, tables=[PGVectorDataPoint.__table__])
+                    await connection.run_sync(
+                        Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
+                    )
 
-    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
+    async def create_data_points(
+        self, collection_name: str, data_points: List[DataPoint]
+    ):
         async with self.get_async_session() as session:
             if not await self.has_collection(collection_name):
                 await self.create_collection(
-                    collection_name = collection_name,
-                    payload_schema = type(data_points[0].payload),
+                    collection_name=collection_name,
+                    payload_schema=type(data_points[0].payload),
                 )
 
             data_vectors = await self.embed_data(
@@ -95,9 +107,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
             class PGVectorDataPoint(Base):
                 __tablename__ = collection_name
-                __table_args__ = {'extend_existing': True}
+                __table_args__ = {"extend_existing": True}
                 # PGVector requires one column to be the primary key
-                primary_key: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+                primary_key: Mapped[int] = mapped_column(
+                    primary_key=True, autoincrement=True
+                )
                 id: Mapped[type(data_points[0].id)]
                 payload = Column(JSON)
                 vector = Column(Vector(vector_size))
@@ -109,10 +123,11 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
 
             pgvector_data_points = [
                 PGVectorDataPoint(
-                    id = data_point.id,
-                    vector = data_vectors[data_index],
-                    payload = serialize_datetime(data_point.payload.dict())
-                ) for (data_index, data_point) in enumerate(data_points)
+                    id=data_point.id,
+                    vector=data_vectors[data_index],
+                    payload=serialize_datetime(data_point.payload.dict()),
+                )
+                for (data_index, data_point) in enumerate(data_points)
             ]
 
             session.add_all(pgvector_data_points)
@@ -127,18 +142,16 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                     query = text(f"SELECT * FROM {collection_name} WHERE id = :id")
                     result = await session.execute(query, {"id": data_point_ids[0]})
                 else:
-                    query = text(f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)")
+                    query = text(
+                        f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)"
+                    )
                     result = await session.execute(query, {"ids": data_point_ids})
 
                 # Fetch all rows
                 rows = result.fetchall()
 
                 return [
-                    ScoredResult(
-                        id=row["id"],
-                        payload=row["payload"],
-                        score=0
-                    )
+                    ScoredResult(id=row["id"], payload=row["payload"], score=0)
                     for row in rows
                 ]
             except Exception as e:
@@ -162,22 +175,31 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         # Use async session to connect to the database
        async with self.get_async_session() as session:
             try:
-                PGVectorDataPoint = Table(collection_name, Base.metadata, autoload_with=self.engine)
+                PGVectorDataPoint = Table(
+                    collection_name, Base.metadata, autoload_with=self.engine
+                )
 
-                closest_items = await session.execute(select(PGVectorDataPoint, PGVectorDataPoint.c.vector.cosine_distance(query_vector).label('similarity')).order_by('similarity').limit(limit))
+                closest_items = await session.execute(
+                    select(
+                        PGVectorDataPoint,
+                        PGVectorDataPoint.c.vector.cosine_distance(query_vector).label(
+                            "similarity"
+                        ),
+                    )
+                    .order_by("similarity")
+                    .limit(limit)
+                )
 
                 vector_list = []
 
                 # Extract distances and find min/max for normalization
                 for vector in closest_items:
-                    #TODO: Add normalization of similarity score
+                    # TODO: Add normalization of similarity score
                     vector_list.append(vector)
 
                 # Create and return ScoredResult objects
                 return [
                     ScoredResult(
-                        id=str(row.id),
-                        payload=row.payload,
-                        score=row.similarity
+                        id=str(row.id), payload=row.payload, score=row.similarity
                     )
                     for row in vector_list
                 ]
@@ -196,12 +218,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
         query_vectors = await self.embedding_engine.embed_text(query_texts)
 
         return asyncio.gather(
-            *[self.search(
-                collection_name = collection_name,
-                query_vector = query_vector,
-                limit = limit,
-                with_vector = with_vectors,
-            ) for query_vector in query_vectors]
+            *[
+                self.search(
+                    collection_name=collection_name,
+                    query_vector=query_vector,
+                    limit=limit,
+                    with_vector=with_vectors,
+                )
+                for query_vector in query_vectors
+            ]
         )
 
     async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
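For reference, here is a minimal sketch of what `serialize_datetime` does, reconstructed from its docstring and the `else: return data` fallback visible in the diff; the branch order in the actual file may differ:

```python
from datetime import datetime


def serialize_datetime(data):
    """Recursively convert datetime objects in dictionaries/lists to ISO format."""
    if isinstance(data, dict):
        # Recurse into each value of the dictionary
        return {key: serialize_datetime(value) for key, value in data.items()}
    elif isinstance(data, list):
        # Recurse into each element of the list
        return [serialize_datetime(item) for item in data]
    elif isinstance(data, datetime):
        # Render the datetime as an ISO 8601 string so it is JSON-serializable
        return data.isoformat()
    else:
        # Anything else passes through unchanged
        return data


# Example: a payload containing a datetime becomes safe to store in a JSON column
payload = {"created_at": datetime(2024, 1, 1), "tags": ["a", "b"]}
print(serialize_datetime(payload))
# {'created_at': '2024-01-01T00:00:00', 'tags': ['a', 'b']}
```

This is why `create_data_points` wraps each `data_point.payload.dict()` in `serialize_datetime` before assigning it to the `JSON` column: raw `datetime` values are not JSON-serializable by default.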