feature: Checkpoint during pgvector integration development
Saves the current state of the pgvector integration work in progress. Feature: #COG-170
This commit is contained in:
parent
c62dfdda9b
commit
268396abdc
3 changed files with 84 additions and 76 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Dict
|
||||
|
||||
from .config import get_relational_config
|
||||
from ..relational.config import get_relational_config
|
||||
|
||||
class VectorConfig(Dict):
|
||||
vector_db_url: str
|
||||
|
|
@ -30,7 +30,7 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
|||
embedding_engine = embedding_engine
|
||||
)
|
||||
elif config["vector_db_provider"] == "pgvector":
|
||||
from .pgvector import PGVectorAdapter
|
||||
from .pgvector.PGVectorAdapter import PGVectorAdapter
|
||||
|
||||
# Get configuration for postgres database
|
||||
relational_config = get_relational_config()
|
||||
|
|
@ -43,7 +43,10 @@ def create_vector_engine(config: VectorConfig, embedding_engine):
|
|||
db_name = config["vector_db_name"]
|
||||
|
||||
connection_string = f"postgresql+asyncpg://{db_username}:{db_password}@{db_host}:{db_port}/{db_name}"
|
||||
return PGVectorAdapter(connection_string)
|
||||
return PGVectorAdapter(connection_string,
|
||||
config["vector_db_key"],
|
||||
embedding_engine
|
||||
)
|
||||
else:
|
||||
from .lancedb.LanceDBAdapter import LanceDBAdapter
|
||||
|
||||
|
|
|
|||
|
|
@ -6,12 +6,21 @@ from ..vector_db_interface import VectorDBInterface, DataPoint
|
|||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase, mapped_column
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
from ...relational.sqlalchemy.SqlAlchemyAdapter import SQLAlchemyAdapter
|
||||
|
||||
|
||||
# Define the models
|
||||
class Base(DeclarativeBase):
    """Declarative base for the SQLAlchemy models defined in this module."""
|
||||
|
||||
class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||
async def create_vector_extension(self):
    """Ensure the pgvector ``vector`` extension exists in the database.

    Opens a session through ``self.get_async_session()``, issues
    ``CREATE EXTENSION IF NOT EXISTS vector`` and commits it. The statement
    is idempotent, so calling this more than once is harmless.
    """
    async with self.get_async_session() as session:
        await session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
        # SQLAlchemy sessions do not autocommit: without an explicit commit
        # the DDL is rolled back when the session is closed.
        # NOTE(review): assumes get_async_session() yields an AsyncSession and
        # does not commit on exit — confirm against SQLAlchemyAdapter.
        await session.commit()
|
||||
|
||||
def __init__(self, connection_string: str,
|
||||
api_key: Optional[str],
|
||||
embedding_engine: EmbeddingEngine
|
||||
|
|
@ -22,17 +31,13 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
|
||||
self.engine = create_async_engine(connection_string)
|
||||
self.sessionmaker = async_sessionmaker(bind=self.engine, expire_on_commit=False)
|
||||
|
||||
# Create pgvector extension in postgres
|
||||
with engine.begin() as connection:
|
||||
connection.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
|
||||
self.create_vector_extension()
|
||||
|
||||
async def embed_data(self, data: list[str]) -> list[list[float]]:
    """Embed the given texts with the adapter's embedding engine.

    Args:
        data: Texts to embed.

    Returns:
        The awaited result of ``self.embedding_engine.embed_text(data)``,
        annotated as a list of float vectors.
    """
    vectors = await self.embedding_engine.embed_text(data)
    return vectors
|
||||
|
||||
async def has_collection(self, collection_name: str) -> bool:
|
||||
connection = await self.get_async_session()
|
||||
async with self.engine.begin() as connection:
|
||||
collection_names = await connection.table_names()
|
||||
return collection_name in collection_names
|
||||
|
||||
|
|
@ -46,7 +51,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
payload: mapped_column(payload_schema)
|
||||
|
||||
if not await self.has_collection(collection_name):
|
||||
connection = await self.get_async_session()
|
||||
async with self.engine.begin() as connection:
|
||||
return await connection.create_table(
|
||||
name = collection_name,
|
||||
schema = PGVectorDataPoint,
|
||||
|
|
@ -54,8 +59,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
)
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
connection = await self.get_async_session()
|
||||
|
||||
async with self.engine.begin() as connection:
|
||||
if not await self.has_collection(collection_name):
|
||||
await self.create_collection(
|
||||
collection_name,
|
||||
|
|
@ -88,7 +92,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
await collection.add(pgvector_data_points)
|
||||
|
||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||
connection = await self.get_async_session()
|
||||
async with self.engine.begin() as connection:
|
||||
collection = await connection.open_table(collection_name)
|
||||
|
||||
if len(data_point_ids) == 1:
|
||||
|
|
@ -116,7 +120,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
if query_text and not query_vector:
|
||||
query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
|
||||
|
||||
connection = await self.get_async_session()
|
||||
async with self.engine.begin() as connection:
|
||||
collection = await connection.open_table(collection_name)
|
||||
|
||||
results = await collection.vector_search(query_vector).limit(limit).to_pandas()
|
||||
|
|
@ -160,7 +164,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
)
|
||||
|
||||
async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
|
||||
connection = await self.get_async_session()
|
||||
async with self.engine.begin() as connection:
|
||||
collection = await connection.open_table(collection_name)
|
||||
results = await collection.delete(f"id IN {tuple(data_point_ids)}")
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
from .PGVectorAdapter import PGVectorAdapter
|
||||
Loading…
Add table
Reference in a new issue