refactor: Change raw SQL queries to SQLalchemy ORM for PGVectorAdapter

Changed raw SQL quries to use SQLalchemy ORM for PGVectorAdapter

Refactor #COG-170
This commit is contained in:
Igor Ilic 2024-10-21 12:59:24 +02:00
parent d2772d22b8
commit 240c660eac

View file

@ -1,13 +1,13 @@
from typing import List, Optional, get_type_hints, Any, Dict from typing import List, Optional, get_type_hints
from sqlalchemy import text, select from sqlalchemy import JSON, Column, Table, select, delete
from sqlalchemy import JSON, Column, Table
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
import asyncio import asyncio
from ..vector_db_interface import VectorDBInterface, DataPoint from ..vector_db_interface import VectorDBInterface, DataPoint
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
@ -49,16 +49,12 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def has_collection(self, collection_name: str) -> bool: async def has_collection(self, collection_name: str) -> bool:
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
# TODO: Switch to using ORM instead of raw query # Load the schema information into the MetaData object
result = await connection.execute( await connection.run_sync(Base.metadata.reflect)
text(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';" if collection_name in Base.metadata.tables:
)
)
tables = result.fetchall()
for table in tables:
if collection_name == table[0]:
return True return True
else:
return False return False
async def create_collection(self, collection_name: str, payload_schema=None): async def create_collection(self, collection_name: str, payload_schema=None):
@ -133,30 +129,32 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
session.add_all(pgvector_data_points) session.add_all(pgvector_data_points)
await session.commit() await session.commit()
async def retrieve(self, collection_name: str, data_point_ids: List[str]): async def get_table(self, collection_name: str) -> Table:
async with AsyncSession(self.engine) as session: """
try: Dynamically loads a table using the given collection name
# Construct the SQL query with an async engine.
# TODO: Switch to using ORM instead of raw query """
if len(data_point_ids) == 1: async with self.engine.begin() as connection:
query = text(f"SELECT * FROM {collection_name} WHERE id = :id") await connection.run_sync(Base.metadata.reflect) # Reflect the metadata
result = await session.execute(query, {"id": data_point_ids[0]}) if collection_name in Base.metadata.tables:
return Base.metadata.tables[collection_name]
else: else:
query = text( raise ValueError(f"Table '{collection_name}' not found.")
f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)"
)
result = await session.execute(query, {"ids": data_point_ids})
# Fetch all rows async def retrieve(self, collection_name: str, data_point_ids: List[str]):
rows = result.fetchall() async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
results = await session.execute(
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
results = results.all()
return [ return [
ScoredResult(id=row["id"], payload=row["payload"], score=0) ScoredResult(id=result.id, payload=result.payload, score=0)
for row in rows for result in results
] ]
except Exception as e:
print(f"Error retrieving data: {e}")
return []
async def search( async def search(
self, self,
@ -175,10 +173,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Use async session to connect to the database # Use async session to connect to the database
async with self.get_async_session() as session: async with self.get_async_session() as session:
try: try:
PGVectorDataPoint = Table( # Get PGVectorDataPoint Table from database
collection_name, Base.metadata, autoload_with=self.engine PGVectorDataPoint = await self.get_table(collection_name)
)
# Find closest vectors to query_vector
closest_items = await session.execute( closest_items = await session.execute(
select( select(
PGVectorDataPoint, PGVectorDataPoint,
@ -230,7 +228,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
) )
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
pass async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
results = await session.execute(
delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
await session.commit()
return results
async def prune(self): async def prune(self):
# Clean up the database if it was set up as temporary # Clean up the database if it was set up as temporary