feature: Checkpoint during pgvector integration development
Saving state of pgvector integration development so far Feature #COG-170
This commit is contained in:
parent
c62dfdda9b
commit
268396abdc
3 changed files with 84 additions and 76 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from .config import get_relational_config
|
from ..relational.config import get_relational_config
|
||||||
|
|
||||||
class VectorConfig(Dict):
|
class VectorConfig(Dict):
|
||||||
vector_db_url: str
|
vector_db_url: str
|
||||||
|
|
@ -30,7 +30,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
embedding_engine = embedding_engine
|
embedding_engine = embedding_engine
|
||||||
)
|
)
|
||||||
elif config["vector_db_provider"] == "pgvector":
|
elif config["vector_db_provider"] == "pgvector":
|
||||||
from .pgvector import PGVectorAdapter
|
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||||
|
|
||||||
# Get configuration for postgres database
|
# Get configuration for postgres database
|
||||||
relational_config = get_relational_config()
|
relational_config = get_relational_config()
|
||||||
|
|
@ -43,7 +43,10 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
||||||
db_name = config["vector_db_name"]
|
db_name = config["vector_db_name"]
|
||||||
|
|
||||||
connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||||
return PGVectorAdapter(connection_string)
|
return PGVectorAdapter(connection_string,
|
||||||
|
config["vector_db_key"],
|
||||||
|
embedding_engine
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,21 @@ from ..vector_db_interface import VectorDBInterface, DataPoint
|
||||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
|
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, mapped_column
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
|
|
||||||
|
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||||
|
|
||||||
|
|
||||||
# Define the models
|
# Define the models
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
async def create_vector_extension(self):
|
||||||
|
async with self.get_async_session() as session:
|
||||||
|
await session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||||
|
|
||||||
def __init__(self, connection_string: str,
|
def __init__(self, connection_string: str,
|
||||||
api_key: Optional[str],
|
api_key: Optional[str],
|
||||||
embedding_engine: EmbeddingEngine
|
embedding_engine: EmbeddingEngine
|
||||||
|
|
@ -22,19 +31,15 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
|
|
||||||
self.engine = create_async_engine(connection_string)
|
self.engine = create_async_engine(connection_string)
|
||||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||||
|
self.create_vector_extension()
|
||||||
# Create pgvector extension in postgres
|
|
||||||
with engine.begin() as connection:
|
|
||||||
connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
|
||||||
|
|
||||||
|
|
||||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
async def embed_data(self, data: list[str]) -> list[list[float]]:
|
||||||
return await self.embedding_engine.embed_text(data)
|
return await self.embedding_engine.embed_text(data)
|
||||||
|
|
||||||
async def has_collection(self, collection_name: str) -> bool:
|
async def has_collection(self, collection_name: str) -> bool:
|
||||||
connection = await self.get_async_session()
|
async with self.engine.begin() as connection:
|
||||||
collection_names = await connection.table_names()
|
collection_names = await connection.table_names()
|
||||||
return collection_name in collection_names
|
return collection_name in collection_names
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, payload_schema = None):
|
async def create_collection(self, collection_name: str, payload_schema = None):
|
||||||
data_point_types = get_type_hints(DataPoint)
|
data_point_types = get_type_hints(DataPoint)
|
||||||
|
|
@ -46,61 +51,60 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
payload: mapped_column(payload_schema)
|
payload: mapped_column(payload_schema)
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
connection = await self.get_async_session()
|
async with self.engine.begin() as connection:
|
||||||
return await connection.create_table(
|
return await connection.create_table(
|
||||||
name = collection_name,
|
name = collection_name,
|
||||||
schema = PGVectorDataPoint,
|
schema = PGVectorDataPoint,
|
||||||
exist_ok = True,
|
exist_ok = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||||
connection = await self.get_async_session()
|
async with self.engine.begin() as connection:
|
||||||
|
if not await self.has_collection(collection_name):
|
||||||
|
await self.create_collection(
|
||||||
|
collection_name,
|
||||||
|
payload_schema = type(data_points[0].payload),
|
||||||
|
)
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
collection = await connection.open_table(collection_name)
|
||||||
await self.create_collection(
|
|
||||||
collection_name,
|
data_vectors = await self.embed_data(
|
||||||
payload_schema = type(data_points[0].payload),
|
[data_point.get_embeddable_data() for data_point in data_points]
|
||||||
)
|
)
|
||||||
|
|
||||||
collection = await connection.open_table(collection_name)
|
IdType = TypeVar("IdType")
|
||||||
|
PayloadSchema = TypeVar("PayloadSchema")
|
||||||
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
data_vectors = await self.embed_data(
|
class PGVectorDataPoint(Base, Generic[IdType, PayloadSchema]):
|
||||||
[data_point.get_embeddable_data() for data_point in data_points]
|
id: Mapped[int] = mapped_column(IdType, primary_key=True)
|
||||||
)
|
vector = mapped_column(Vector(vector_size))
|
||||||
|
payload: mapped_column(PayloadSchema)
|
||||||
|
|
||||||
IdType = TypeVar("IdType")
|
pgvector_data_points = [
|
||||||
PayloadSchema = TypeVar("PayloadSchema")
|
PGVectorDataPoint[type(data_point.id), type(data_point.payload)](
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
id = data_point.id,
|
||||||
|
vector = data_vectors[data_index],
|
||||||
|
payload = data_point.payload,
|
||||||
|
) for (data_index, data_point) in enumerate(data_points)
|
||||||
|
]
|
||||||
|
|
||||||
class PGVectorDataPoint(Base, Generic[IdType, PayloadSchema]):
|
await collection.add(pgvector_data_points)
|
||||||
id: Mapped[int] = mapped_column(IdType, primary_key=True)
|
|
||||||
vector = mapped_column(Vector(vector_size))
|
|
||||||
payload: mapped_column(PayloadSchema)
|
|
||||||
|
|
||||||
pgvector_data_points = [
|
|
||||||
PGVectorDataPoint[type(data_point.id), type(data_point.payload)](
|
|
||||||
id = data_point.id,
|
|
||||||
vector = data_vectors[data_index],
|
|
||||||
payload = data_point.payload,
|
|
||||||
) for (data_index, data_point) in enumerate(data_points)
|
|
||||||
]
|
|
||||||
|
|
||||||
await collection.add(pgvector_data_points)
|
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||||
connection = await self.get_async_session()
|
async with self.engine.begin() as connection:
|
||||||
collection = await connection.open_table(collection_name)
|
collection = await connection.open_table(collection_name)
|
||||||
|
|
||||||
if len(data_point_ids) == 1:
|
if len(data_point_ids) == 1:
|
||||||
results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
|
results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
|
||||||
else:
|
else:
|
||||||
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
|
results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()
|
||||||
|
|
||||||
return [ScoredResult(
|
return [ScoredResult(
|
||||||
id = result["id"],
|
id = result["id"],
|
||||||
payload = result["payload"],
|
payload = result["payload"],
|
||||||
score = 0,
|
score = 0,
|
||||||
) for result in results.to_dict("index").values()]
|
) for result in results.to_dict("index").values()]
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
|
|
@ -116,30 +120,30 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
if query_text and not query_vector:
|
if query_text and not query_vector:
|
||||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||||
|
|
||||||
connection = await self.get_async_session()
|
async with self.engine.begin() as connection:
|
||||||
collection = await connection.open_table(collection_name)
|
collection = await connection.open_table(collection_name)
|
||||||
|
|
||||||
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
||||||
|
|
||||||
result_values = list(results.to_dict("index").values())
|
result_values = list(results.to_dict("index").values())
|
||||||
|
|
||||||
min_value = 100
|
min_value = 100
|
||||||
max_value = 0
|
max_value = 0
|
||||||
|
|
||||||
for result in result_values:
|
for result in result_values:
|
||||||
value = float(result["_distance"])
|
value = float(result["_distance"])
|
||||||
if value > max_value:
|
if value > max_value:
|
||||||
max_value = value
|
max_value = value
|
||||||
if value < min_value:
|
if value < min_value:
|
||||||
min_value = value
|
min_value = value
|
||||||
|
|
||||||
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
|
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in result_values]
|
||||||
|
|
||||||
return [ScoredResult(
|
return [ScoredResult(
|
||||||
id = str(result["id"]),
|
id = str(result["id"]),
|
||||||
payload = result["payload"],
|
payload = result["payload"],
|
||||||
score = normalized_values[value_index],
|
score = normalized_values[value_index],
|
||||||
) for value_index, result in enumerate(result_values)]
|
) for value_index, result in enumerate(result_values)]
|
||||||
|
|
||||||
async def batch_search(
|
async def batch_search(
|
||||||
self,
|
self,
|
||||||
|
|
@ -160,10 +164,10 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||||
connection = await self.get_async_session()
|
async with self.engine.begin() as connection:
|
||||||
collection = await connection.open_table(collection_name)
|
collection = await connection.open_table(collection_name)
|
||||||
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
|
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def prune(self):
|
async def prune(self):
|
||||||
# Clean up the database if it was set up as temporary
|
# Clean up the database if it was set up as temporary
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .PGVectorAdapter import PGVectorAdapter
|
||||||
Loading…
Add table
Reference in a new issue