refactor: Change raw SQL queries to SQLalchemy ORM for PGVectorAdapter

Changed raw SQL quries to use SQLalchemy ORM for PGVectorAdapter

Refactor #COG-170
This commit is contained in:
Igor Ilic 2024-10-21 12:59:24 +02:00
parent d2772d22b8
commit 240c660eac

View file

@ -1,13 +1,13 @@
from typing import List, Optional, get_type_hints, Any, Dict from typing import List, Optional, get_type_hints
from sqlalchemy import text, select from sqlalchemy import JSON, Column, Table, select, delete
from sqlalchemy import JSON, Column, Table
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
import asyncio import asyncio
from ..vector_db_interface import VectorDBInterface, DataPoint from ..vector_db_interface import VectorDBInterface, DataPoint
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from pgvector.sqlalchemy import Vector from pgvector.sqlalchemy import Vector
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
@ -49,17 +49,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def has_collection(self, collection_name: str) -> bool: async def has_collection(self, collection_name: str) -> bool:
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
# TODO: Switch to using ORM instead of raw query # Load the schema information into the MetaData object
result = await connection.execute( await connection.run_sync(Base.metadata.reflect)
text(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';" if collection_name in Base.metadata.tables:
) return True
) else:
tables = result.fetchall() return False
for table in tables:
if collection_name == table[0]:
return True
return False
async def create_collection(self, collection_name: str, payload_schema=None): async def create_collection(self, collection_name: str, payload_schema=None):
data_point_types = get_type_hints(DataPoint) data_point_types = get_type_hints(DataPoint)
@ -133,30 +129,32 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
session.add_all(pgvector_data_points) session.add_all(pgvector_data_points)
await session.commit() await session.commit()
async def get_table(self, collection_name: str) -> Table:
"""
Dynamically loads a table using the given collection name
with an async engine.
"""
async with self.engine.begin() as connection:
await connection.run_sync(Base.metadata.reflect) # Reflect the metadata
if collection_name in Base.metadata.tables:
return Base.metadata.tables[collection_name]
else:
raise ValueError(f"Table '{collection_name}' not found.")
async def retrieve(self, collection_name: str, data_point_ids: List[str]): async def retrieve(self, collection_name: str, data_point_ids: List[str]):
async with AsyncSession(self.engine) as session: async with self.get_async_session() as session:
try: # Get PGVectorDataPoint Table from database
# Construct the SQL query PGVectorDataPoint = await self.get_table(collection_name)
# TODO: Switch to using ORM instead of raw query
if len(data_point_ids) == 1:
query = text(f"SELECT * FROM {collection_name} WHERE id = :id")
result = await session.execute(query, {"id": data_point_ids[0]})
else:
query = text(
f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)"
)
result = await session.execute(query, {"ids": data_point_ids})
# Fetch all rows results = await session.execute(
rows = result.fetchall() select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
results = results.all()
return [ return [
ScoredResult(id=row["id"], payload=row["payload"], score=0) ScoredResult(id=result.id, payload=result.payload, score=0)
for row in rows for result in results
] ]
except Exception as e:
print(f"Error retrieving data: {e}")
return []
async def search( async def search(
self, self,
@ -175,10 +173,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Use async session to connect to the database # Use async session to connect to the database
async with self.get_async_session() as session: async with self.get_async_session() as session:
try: try:
PGVectorDataPoint = Table( # Get PGVectorDataPoint Table from database
collection_name, Base.metadata, autoload_with=self.engine PGVectorDataPoint = await self.get_table(collection_name)
)
# Find closest vectors to query_vector
closest_items = await session.execute( closest_items = await session.execute(
select( select(
PGVectorDataPoint, PGVectorDataPoint,
@ -230,7 +228,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
) )
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
pass async with self.get_async_session() as session:
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
results = await session.execute(
delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
)
await session.commit()
return results
async def prune(self): async def prune(self):
# Clean up the database if it was set up as temporary # Clean up the database if it was set up as temporary