refactor: Change raw SQL queries to SQLalchemy ORM for PGVectorAdapter
Changed raw SQL quries to use SQLalchemy ORM for PGVectorAdapter Refactor #COG-170
This commit is contained in:
parent
d2772d22b8
commit
240c660eac
1 changed files with 45 additions and 40 deletions
|
|
@ -1,13 +1,13 @@
|
||||||
from typing import List, Optional, get_type_hints, Any, Dict
|
from typing import List, Optional, get_type_hints
|
||||||
from sqlalchemy import text, select
|
from sqlalchemy import JSON, Column, Table, select, delete
|
||||||
from sqlalchemy import JSON, Column, Table
|
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from ..vector_db_interface import VectorDBInterface, DataPoint
|
from ..vector_db_interface import VectorDBInterface, DataPoint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||||
from pgvector.sqlalchemy import Vector
|
from pgvector.sqlalchemy import Vector
|
||||||
|
|
||||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||||
|
|
@ -49,17 +49,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
async def has_collection(self, collection_name: str) -> bool:
|
async def has_collection(self, collection_name: str) -> bool:
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
# TODO: Switch to using ORM instead of raw query
|
# Load the schema information into the MetaData object
|
||||||
result = await connection.execute(
|
await connection.run_sync(Base.metadata.reflect)
|
||||||
text(
|
|
||||||
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
|
if collection_name in Base.metadata.tables:
|
||||||
)
|
return True
|
||||||
)
|
else:
|
||||||
tables = result.fetchall()
|
return False
|
||||||
for table in tables:
|
|
||||||
if collection_name == table[0]:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema=None):
|
async def create_collection(self, collection_name: str, payload_schema=None):
|
||||||
data_point_types = get_type_hints(DataPoint)
|
data_point_types = get_type_hints(DataPoint)
|
||||||
|
|
@ -133,30 +129,32 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
session.add_all(pgvector_data_points)
|
session.add_all(pgvector_data_points)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
async def get_table(self, collection_name: str) -> Table:
|
||||||
|
"""
|
||||||
|
Dynamically loads a table using the given collection name
|
||||||
|
with an async engine.
|
||||||
|
"""
|
||||||
|
async with self.engine.begin() as connection:
|
||||||
|
await connection.run_sync(Base.metadata.reflect) # Reflect the metadata
|
||||||
|
if collection_name in Base.metadata.tables:
|
||||||
|
return Base.metadata.tables[collection_name]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Table '{collection_name}' not found.")
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: List[str]):
|
||||||
async with AsyncSession(self.engine) as session:
|
async with self.get_async_session() as session:
|
||||||
try:
|
# Get PGVectorDataPoint Table from database
|
||||||
# Construct the SQL query
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
# TODO: Switch to using ORM instead of raw query
|
|
||||||
if len(data_point_ids) == 1:
|
|
||||||
query = text(f"SELECT * FROM {collection_name} WHERE id = :id")
|
|
||||||
result = await session.execute(query, {"id": data_point_ids[0]})
|
|
||||||
else:
|
|
||||||
query = text(
|
|
||||||
f"SELECT * FROM {collection_name} WHERE id = ANY(:ids)"
|
|
||||||
)
|
|
||||||
result = await session.execute(query, {"ids": data_point_ids})
|
|
||||||
|
|
||||||
# Fetch all rows
|
results = await session.execute(
|
||||||
rows = result.fetchall()
|
select(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
|
||||||
|
)
|
||||||
|
results = results.all()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ScoredResult(id=row["id"], payload=row["payload"], score=0)
|
ScoredResult(id=result.id, payload=result.payload, score=0)
|
||||||
for row in rows
|
for result in results
|
||||||
]
|
]
|
||||||
except Exception as e:
|
|
||||||
print(f"Error retrieving data: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
|
|
@ -175,10 +173,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
# Use async session to connect to the database
|
# Use async session to connect to the database
|
||||||
async with self.get_async_session() as session:
|
async with self.get_async_session() as session:
|
||||||
try:
|
try:
|
||||||
PGVectorDataPoint = Table(
|
# Get PGVectorDataPoint Table from database
|
||||||
collection_name, Base.metadata, autoload_with=self.engine
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
)
|
|
||||||
|
|
||||||
|
# Find closest vectors to query_vector
|
||||||
closest_items = await session.execute(
|
closest_items = await session.execute(
|
||||||
select(
|
select(
|
||||||
PGVectorDataPoint,
|
PGVectorDataPoint,
|
||||||
|
|
@ -230,7 +228,14 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||||
pass
|
async with self.get_async_session() as session:
|
||||||
|
# Get PGVectorDataPoint Table from database
|
||||||
|
PGVectorDataPoint = await self.get_table(collection_name)
|
||||||
|
results = await session.execute(
|
||||||
|
delete(PGVectorDataPoint).where(PGVectorDataPoint.c.id.in_(data_point_ids))
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
return results
|
||||||
|
|
||||||
async def prune(self):
|
async def prune(self):
|
||||||
# Clean up the database if it was set up as temporary
|
# Clean up the database if it was set up as temporary
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue