From 240c660eac3a294c0d2cf235e6934de2ab145f3d Mon Sep 17 00:00:00 2001 From: Igor Ilic Date: Mon, 21 Oct 2024 12:59:24 +0200 Subject: [PATCH] refactor: Change raw SQL queries to SQLAlchemy ORM for PGVectorAdapter Changed raw SQL queries to use SQLAlchemy ORM for PGVectorAdapter Refactor #COG-170 --- .../vector/pgvector/PGVectorAdapter.py | 85 ++++++++++--------- 1 file changed, 45 insertions(+), 40 deletions(-) diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 9b8ef1c88..b9a716663 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -1,13 +1,13 @@ -from typing import List, Optional, get_type_hints, Any, Dict -from sqlalchemy import text, select -from sqlalchemy import JSON, Column, Table +from typing import List, Optional, get_type_hints +from sqlalchemy import JSON, Column, Table, select, delete + from ..models.ScoredResult import ScoredResult import asyncio from ..vector_db_interface import VectorDBInterface, DataPoint from sqlalchemy.orm import Mapped, mapped_column from ..embeddings.EmbeddingEngine import EmbeddingEngine -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from pgvector.sqlalchemy import Vector from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter @@ -49,17 +49,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): async def has_collection(self, collection_name: str) -> bool: async with self.engine.begin() as connection: - # TODO: Switch to using ORM instead of raw query - result = await connection.execute( - text( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';" - ) - ) - tables = result.fetchall() - for table in tables: - if collection_name == table[0]: - return
True - return False + # Load the schema information into the MetaData object + await connection.run_sync(Base.metadata.reflect) + + if collection_name in Base.metadata.tables: + return True + else: + return False async def create_collection(self, collection_name: str, payload_schema=None): data_point_types = get_type_hints(DataPoint) @@ -133,30 +129,32 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): session.add_all(pgvector_data_points) await session.commit() + async def get_table(self, collection_name: str) -> Table: + """ + Dynamically loads a table using the given collection name + with an async engine. + """ + async with self.engine.begin() as connection: + await connection.run_sync(Base.metadata.reflect) # Reflect the metadata + if collection_name in Base.metadata.tables: + return Base.metadata.tables[collection_name] + else: + raise ValueError(f"Table '{collection_name}' not found.") + async def retrieve(self, collection_name: str, data_point_ids: List[str]): - async with AsyncSession(self.engine) as session: - try: - # Construct the SQL query - # TODO: Switch to using ORM instead of raw query - if len(data_point_ids) == 1: - query = text(f"SELECT * FROM {collection_name} WHERE id = :id") - result = await session.execute(query, {"id": data_point_ids[0]}) - else: - query = text( - f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)" - ) - result = await session.execute(query, {"ids": data_point_ids}) + async with self.get_async_session() as session: + # Get PGVectorDataPoint Table from database + PGVectorDataPoint = await self.get_table(collection_name) - # Fetch all rows - rows = result.fetchall() + results = await session.execute( + select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids)) + ) + results = results.all() - return [ - ScoredResult(id=row["id"], payload=row["payload"], score=0) - for row in rows - ] - except Exception as e: - print(f"Error retrieving data: {e}") - return [] + return [ + 
ScoredResult(id=result.id, payload=result.payload, score=0) + for result in results + ] async def search( self, @@ -175,10 +173,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): # Use async session to connect to the database async with self.get_async_session() as session: try: - PGVectorDataPoint = Table( - collection_name, Base.metadata, autoload_with=self.engine - ) + # Get PGVectorDataPoint Table from database + PGVectorDataPoint = await self.get_table(collection_name) + # Find closest vectors to query_vector closest_items = await session.execute( select( PGVectorDataPoint, @@ -230,7 +228,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): ) async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): - pass + async with self.get_async_session() as session: + # Get PGVectorDataPoint Table from database + PGVectorDataPoint = await self.get_table(collection_name) + results = await session.execute( + delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids)) + ) + await session.commit() + return results async def prune(self): # Clean up the database if it was set up as temporary