feature: Checkpoint during pgvector integration development

Saving state of pgvector integration development so far

Feature #COG-170
This commit is contained in:
Igor Ilic 2024-10-11 17:11:05 +02:00
parent c62dfdda9b
commit 268396abdc
3 changed files with 84 additions and 76 deletions

View file

@@ -1,6 +1,6 @@
from typing import Dict from typing import Dict
from .config import get_relational_config from ..relational.config import get_relational_config
class VectorConfig(Dict): class VectorConfig(Dict):
vector_db_url: str vector_db_url: str
@@ -30,7 +30,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
embedding_engine = embedding_engine embedding_engine = embedding_engine
) )
elif config["vector_db_provider"] == "pgvector": elif config["vector_db_provider"] == "pgvector":
from .pgvector import PGVectorAdapter from .pgvector.PGVectorAdapter import PGVectorAdapter
# Get configuration for postgres database # Get configuration for postgres database
relational_config = get_relational_config() relational_config = get_relational_config()
@@ -43,7 +43,10 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
db_name = config["vector_db_name"] db_name = config["vector_db_name"]
connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}" connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
return PGVectorAdapter(connection_string) return PGVectorAdapter(connection_string,
config["vector_db_key"],
embedding_engine
)
else: else:
from .lancedb.LanceDBAdapter import LanceDBAdapter from .lancedb.LanceDBAdapter import LanceDBAdapter

View file

@@ -6,12 +6,21 @@ from ..vector_db_interface import VectorDBInterface, DataPoint
from ..embeddings.EmbeddingEngine import EmbeddingEngine from ..embeddings.EmbeddingEngine import EmbeddingEngine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, mapped_column
from pgvector.sqlalchemy import Vector
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
# Define the models # Define the models
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass pass
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
async def create_vector_extension(self):
async with self.get_async_session() as session:
await session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
def __init__(self, connection_string: str, def __init__(self, connection_string: str,
api_key: Optional[str], api_key: Optional[str],
embedding_engine: EmbeddingEngine embedding_engine: EmbeddingEngine
@@ -22,19 +31,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self.engine = create_async_engine(connection_string) self.engine = create_async_engine(connection_string)
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False) self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
self.create_vector_extension()
# Create pgvector extension in postgres
with engine.begin() as connection:
connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
async def embed_data(self, data: list[str]) -> list[list[float]]: async def embed_data(self, data: list[str]) -> list[list[float]]:
return await self.embedding_engine.embed_text(data) return await self.embedding_engine.embed_text(data)
async def has_collection(self, collection_name: str) -> bool: async def has_collection(self, collection_name: str) -> bool:
connection = await self.get_async_session() async with self.engine.begin() as connection:
collection_names = await connection.table_names() collection_names = await connection.table_names()
return collection_name in collection_names return collection_name in collection_names
async def create_collection(self, collection_name: str, payload_schema = None): async def create_collection(self, collection_name: str, payload_schema = None):
data_point_types = get_type_hints(DataPoint) data_point_types = get_type_hints(DataPoint)
@@ -46,61 +51,60 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
payload: mapped_column(payload_schema) payload: mapped_column(payload_schema)
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
connection = await self.get_async_session() async with self.engine.begin() as connection:
return await connection.create_table( return await connection.create_table(
name = collection_name, name = collection_name,
schema = PGVectorDataPoint, schema = PGVectorDataPoint,
exist_ok = True, exist_ok = True,
) )
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
connection = await self.get_async_session() async with self.engine.begin() as connection:
if not await self.has_collection(collection_name):
await self.create_collection(
collection_name,
payload_schema = type(data_points[0].payload),
)
if not await self.has_collection(collection_name): collection = await connection.open_table(collection_name)
await self.create_collection(
collection_name, data_vectors = await self.embed_data(
payload_schema = type(data_points[0].payload), [data_point.get_embeddable_data() for data_point in data_points]
) )
collection = await connection.open_table(collection_name) IdType = TypeVar("IdType")
PayloadSchema = TypeVar("PayloadSchema")
vector_size = self.embedding_engine.get_vector_size()
data_vectors = await self.embed_data( class PGVectorDataPoint(Base, Generic[IdType, PayloadSchema]):
[data_point.get_embeddable_data() for data_point in data_points] id: Mapped[int] = mapped_column(IdType, primary_key=True)
) vector = mapped_column(Vector(vector_size))
payload: mapped_column(PayloadSchema)
IdType = TypeVar("IdType") pgvector_data_points = [
PayloadSchema = TypeVar("PayloadSchema") PGVectorDataPoint[type(data_point.id), type(data_point.payload)](
vector_size = self.embedding_engine.get_vector_size() id = data_point.id,
vector = data_vectors[data_index],
payload = data_point.payload,
) for (data_index, data_point) in enumerate(data_points)
]
class PGVectorDataPoint(Base, Generic[IdType, PayloadSchema]): await collection.add(pgvector_data_points)
id: Mapped[int] = mapped_column(IdType, primary_key=True)
vector = mapped_column(Vector(vector_size))
payload: mapped_column(PayloadSchema)
pgvector_data_points = [
PGVectorDataPoint[type(data_point.id), type(data_point.payload)](
id = data_point.id,
vector = data_vectors[data_index],
payload = data_point.payload,
) for (data_index, data_point) in enumerate(data_points)
]
await collection.add(pgvector_data_points)
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]):
connection = await self.get_async_session() async with self.engine.begin() as connection:
collection = await connection.open_table(collection_name) collection = await connection.open_table(collection_name)
if len(data_point_ids) == 1: if len(data_point_ids) == 1:
results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas() results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
else: else:
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas() results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
return [ScoredResult( return [ScoredResult(
id = result["id"], id = result["id"],
payload = result["payload"], payload = result["payload"],
score = 0, score = 0,
) for result in results.to_dict("index").values()] ) for result in results.to_dict("index").values()]
async def search( async def search(
self, self,
@@ -116,30 +120,30 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
if query_text and not query_vector: if query_text and not query_vector:
query_vector = (await self.embedding_engine.embed_text([query_text]))[0] query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
connection = await self.get_async_session() async with self.engine.begin() as connection:
collection = await connection.open_table(collection_name) collection = await connection.open_table(collection_name)
results = await collection.vector_search(query_vector).limit(limit).to_pandas() results = await collection.vector_search(query_vector).limit(limit).to_pandas()
result_values = list(results.to_dict("index").values()) result_values = list(results.to_dict("index").values())
min_value = 100 min_value = 100
max_value = 0 max_value = 0
for result in result_values: for result in result_values:
value = float(result["_distance"]) value = float(result["_distance"])
if value > max_value: if value > max_value:
max_value = value max_value = value
if value < min_value: if value < min_value:
min_value = value min_value = value
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values] normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
return [ScoredResult( return [ScoredResult(
id = str(result["id"]), id = str(result["id"]),
payload = result["payload"], payload = result["payload"],
score = normalized_values[value_index], score = normalized_values[value_index],
) for value_index, result in enumerate(result_values)] ) for value_index, result in enumerate(result_values)]
async def batch_search( async def batch_search(
self, self,
@@ -160,10 +164,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
) )
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]): async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
connection = await self.get_async_session() async with self.engine.begin() as connection:
collection = await connection.open_table(collection_name) collection = await connection.open_table(collection_name)
results = await collection.delete(f"id IN {tuple(data_point_ids)}") results = await collection.delete(f"id IN {tuple(data_point_ids)}")
return results return results
async def prune(self): async def prune(self):
# Clean up the database if it was set up as temporary # Clean up the database if it was set up as temporary

View file

@@ -0,0 +1 @@
from .PGVectorAdapter import PGVectorAdapter