refactor: Change raw SQL queries to SQLAlchemy ORM for PGVectorAdapter

Changed raw SQL queries to use SQLAlchemy ORM for PGVectorAdapter

Refactor #COG-170
This commit is contained in:
Igor Ilic 2024-10-21 12:59:24 +02:00
parent d2772d22b8
commit 240c660eac

View file

@ -1,13 +1,13 @@
from typing import List, Optional, get_type_hints, Any, Dict
from sqlalchemy import text, select
from sqlalchemy import JSON, Column, Table
from typing import List, Optional, get_type_hints
from sqlalchemy import JSON, Column, Table, select, delete
from ..models.ScoredResult import ScoredResult
import asyncio
from ..vector_db_interface import VectorDBInterface, DataPoint
from sqlalchemy.orm import Mapped, mapped_column
from ..embeddings.EmbeddingEngine import EmbeddingEngine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from pgvector.sqlalchemy import Vector
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
@ -49,17 +49,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def has_collection(self, collection_name: str) -> bool:
    """
    Check whether a table named `collection_name` exists in the database.

    The database itself is queried through SQLAlchemy's Inspector rather
    than looking the name up in `Base.metadata.tables`: that mapping also
    holds tables that were only declared on `Base` or reflected earlier,
    so it can report a table that does not (or no longer) exist.

    Parameters:
        collection_name: Name of the table backing the collection.

    Returns:
        True if the inspector reports the table, False otherwise.
    """
    # Imported locally so this method stays self-contained.
    from sqlalchemy import inspect

    def table_exists(sync_connection) -> bool:
        # Inspector only works on a synchronous connection.
        return inspect(sync_connection).has_table(collection_name)

    async with self.engine.begin() as connection:
        # run_sync hands the underlying sync connection to the callable.
        return await connection.run_sync(table_exists)
async def create_collection(self, collection_name: str, payload_schema=None):
data_point_types = get_type_hints(DataPoint)
@ -133,30 +129,32 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
session.add_all(pgvector_data_points)
await session.commit()
async def get_table(self, collection_name: str) -> Table:
    """
    Look up the table for a collection after reflecting the database schema.

    The schema is reflected into `Base.metadata` over a connection obtained
    from the async engine, and the table registered under `collection_name`
    is returned.

    Raises:
        ValueError: If no table with that name is present after reflection.
    """
    async with self.engine.begin() as connection:
        # Reflection is a synchronous API, so it is dispatched via run_sync.
        await connection.run_sync(Base.metadata.reflect)

        table = Base.metadata.tables.get(collection_name)
        if table is None:
            raise ValueError(f"Table '{collection_name}' not found.")
        return table
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
    """
    Fetch the data points with the given ids from a collection.

    The query is built with the SQLAlchemy expression API against the
    table returned by `get_table`, so the collection name is never
    interpolated into a SQL string.

    Parameters:
        collection_name: Name of the table backing the collection.
        data_point_ids: Ids of the rows to fetch.

    Returns:
        A list with one ScoredResult per matching row, carrying the row's
        id and payload. The score is always 0 because no similarity is
        computed here.

    Raises:
        ValueError: Propagated from `get_table` when the table is missing.
        Database errors are not swallowed; they propagate to the caller.
    """
    async with self.get_async_session() as session:
        # Get PGVectorDataPoint Table from database
        PGVectorDataPoint = await self.get_table(collection_name)

        query = select(PGVectorDataPoint).where(
            PGVectorDataPoint.c.id.in_(data_point_ids)
        )
        rows = (await session.execute(query)).all()

        return [
            ScoredResult(id=row.id, payload=row.payload, score=0)
            for row in rows
        ]
async def search(
self,
@ -175,10 +173,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Use async session to connect to the database
async with self.get_async_session() as session:
try:
PGVectorDataPoint = Table(
collection_name, Base.metadata, autoload_with=self.engine
)
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
# Find closest vectors to query_vector
closest_items = await session.execute(
select(
PGVectorDataPoint,
@ -230,7 +228,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
)
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
    """
    Delete the rows whose id is in `data_point_ids` from a collection.

    Parameters:
        collection_name: Name of the table backing the collection.
        data_point_ids: Ids of the rows to delete.

    Returns:
        The result object produced by executing the DELETE statement.
    """
    async with self.get_async_session() as session:
        # Resolve the table that stores this collection's data points.
        data_point_table = await self.get_table(collection_name)

        delete_statement = delete(data_point_table).where(
            data_point_table.c.id.in_(data_point_ids)
        )
        outcome = await session.execute(delete_statement)

        # Commit so the deletion is persisted before returning.
        await session.commit()
        return outcome
async def prune(self):
# Clean up the database if it was set up as temporary