Lancedb async lock (#1222)
<!-- .github/pull_request_template.md --> ## Description <!-- Provide a clear description of the changes in this PR --> ## DCO Affirmation I affirm that all code in every commit of this pull request conforms to the terms of the Topoteretes Developer Certificate of Origin. --------- Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
This commit is contained in:
parent
c33536685d
commit
a75a79f012
7 changed files with 55 additions and 62 deletions
|
|
@ -53,6 +53,7 @@ class SQLAlchemyAdapter:
|
||||||
self.engine = create_async_engine(
|
self.engine = create_async_engine(
|
||||||
connection_string,
|
connection_string,
|
||||||
poolclass=NullPool,
|
poolclass=NullPool,
|
||||||
|
connect_args={"timeout": 30},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.engine = create_async_engine(
|
self.engine = create_async_engine(
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
self.url = url
|
self.url = url
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
|
self.VECTOR_DB_LOCK = asyncio.Lock()
|
||||||
|
|
||||||
async def get_connection(self):
|
async def get_connection(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -127,12 +128,14 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
payload: payload_schema
|
payload: payload_schema
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
connection = await self.get_connection()
|
async with self.VECTOR_DB_LOCK:
|
||||||
return await connection.create_table(
|
if not await self.has_collection(collection_name):
|
||||||
name=collection_name,
|
connection = await self.get_connection()
|
||||||
schema=LanceDataPoint,
|
return await connection.create_table(
|
||||||
exist_ok=True,
|
name=collection_name,
|
||||||
)
|
schema=LanceDataPoint,
|
||||||
|
exist_ok=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_collection(self, collection_name: str):
|
async def get_collection(self, collection_name: str):
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
|
|
@ -145,10 +148,12 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
payload_schema = type(data_points[0])
|
payload_schema = type(data_points[0])
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
if not await self.has_collection(collection_name):
|
||||||
await self.create_collection(
|
async with self.VECTOR_DB_LOCK:
|
||||||
collection_name,
|
if not await self.has_collection(collection_name):
|
||||||
payload_schema,
|
await self.create_collection(
|
||||||
)
|
collection_name,
|
||||||
|
payload_schema,
|
||||||
|
)
|
||||||
|
|
||||||
collection = await self.get_collection(collection_name)
|
collection = await self.get_collection(collection_name)
|
||||||
|
|
||||||
|
|
@ -188,12 +193,13 @@ class LanceDBAdapter(VectorDBInterface):
|
||||||
for (data_point_index, data_point) in enumerate(data_points)
|
for (data_point_index, data_point) in enumerate(data_points)
|
||||||
]
|
]
|
||||||
|
|
||||||
await (
|
async with self.VECTOR_DB_LOCK:
|
||||||
collection.merge_insert("id")
|
await (
|
||||||
.when_matched_update_all()
|
collection.merge_insert("id")
|
||||||
.when_not_matched_insert_all()
|
.when_matched_update_all()
|
||||||
.execute(lance_data_points)
|
.when_not_matched_insert_all()
|
||||||
)
|
.execute(lance_data_points)
|
||||||
|
)
|
||||||
|
|
||||||
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
async def retrieve(self, collection_name: str, data_point_ids: list[str]):
|
||||||
collection = await self.get_collection(collection_name)
|
collection = await self.get_collection(collection_name)
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
self.db_uri: str = connection_string
|
self.db_uri: str = connection_string
|
||||||
|
self.VECTOR_DB_LOCK = asyncio.Lock()
|
||||||
|
|
||||||
relational_db = get_relational_engine()
|
relational_db = get_relational_engine()
|
||||||
|
|
||||||
|
|
@ -124,40 +125,41 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
||||||
data_point_types = get_type_hints(DataPoint)
|
data_point_types = get_type_hints(DataPoint)
|
||||||
vector_size = self.embedding_engine.get_vector_size()
|
vector_size = self.embedding_engine.get_vector_size()
|
||||||
|
|
||||||
if not await self.has_collection(collection_name):
|
async with self.VECTOR_DB_LOCK:
|
||||||
|
if not await self.has_collection(collection_name):
|
||||||
|
|
||||||
class PGVectorDataPoint(Base):
|
class PGVectorDataPoint(Base):
|
||||||
"""
|
"""
|
||||||
Represent a point in a vector data space with associated data and vector representation.
|
Represent a point in a vector data space with associated data and vector representation.
|
||||||
|
|
||||||
This class inherits from Base and is associated with a database table defined by
|
This class inherits from Base and is associated with a database table defined by
|
||||||
__tablename__. It maintains the following public methods and instance variables:
|
__tablename__. It maintains the following public methods and instance variables:
|
||||||
|
|
||||||
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
|
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
|
||||||
|
|
||||||
Instance variables:
|
Instance variables:
|
||||||
- id: Identifier for the data point, defined by data_point_types.
|
- id: Identifier for the data point, defined by data_point_types.
|
||||||
- payload: JSON data associated with the data point.
|
- payload: JSON data associated with the data point.
|
||||||
- vector: Vector representation of the data point, with size defined by vector_size.
|
- vector: Vector representation of the data point, with size defined by vector_size.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = collection_name
|
__tablename__ = collection_name
|
||||||
__table_args__ = {"extend_existing": True}
|
__table_args__ = {"extend_existing": True}
|
||||||
# PGVector requires one column to be the primary key
|
# PGVector requires one column to be the primary key
|
||||||
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
|
||||||
payload = Column(JSON)
|
payload = Column(JSON)
|
||||||
vector = Column(self.Vector(vector_size))
|
vector = Column(self.Vector(vector_size))
|
||||||
|
|
||||||
def __init__(self, id, payload, vector):
|
def __init__(self, id, payload, vector):
|
||||||
self.id = id
|
self.id = id
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
self.vector = vector
|
self.vector = vector
|
||||||
|
|
||||||
async with self.engine.begin() as connection:
|
async with self.engine.begin() as connection:
|
||||||
if len(Base.metadata.tables.keys()) > 0:
|
if len(Base.metadata.tables.keys()) > 0:
|
||||||
await connection.run_sync(
|
await connection.run_sync(
|
||||||
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
|
||||||
)
|
)
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
retry=retry_if_exception_type(DeadlockDetectedError),
|
retry=retry_if_exception_type(DeadlockDetectedError),
|
||||||
|
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
# class PineconeVectorDB(VectorDB):
|
|
||||||
# def __init__(self, *args, **kwargs):
|
|
||||||
# super().__init__(*args, **kwargs)
|
|
||||||
# self.init_pinecone(self.index_name)
|
|
||||||
#
|
|
||||||
# def init_pinecone(self, index_name):
|
|
||||||
# # Pinecone initialization logic
|
|
||||||
# pass
|
|
||||||
|
|
@ -132,7 +132,6 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
|
|
||||||
def __init__(self, url, api_key, embedding_engine: EmbeddingEngine, qdrant_path=None):
|
def __init__(self, url, api_key, embedding_engine: EmbeddingEngine, qdrant_path=None):
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
|
|
||||||
if qdrant_path is not None:
|
if qdrant_path is not None:
|
||||||
self.qdrant_path = qdrant_path
|
self.qdrant_path = qdrant_path
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from cognee.shared.logging_utils import get_logger
|
from cognee.shared.logging_utils import get_logger
|
||||||
|
|
||||||
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
|
||||||
|
|
@ -8,9 +6,6 @@ from cognee.infrastructure.engine import DataPoint
|
||||||
|
|
||||||
logger = get_logger("index_data_points")
|
logger = get_logger("index_data_points")
|
||||||
|
|
||||||
# A single lock shared by all coroutines
|
|
||||||
vector_index_lock = asyncio.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
async def index_data_points(data_points: list[DataPoint]):
|
async def index_data_points(data_points: list[DataPoint]):
|
||||||
created_indexes = {}
|
created_indexes = {}
|
||||||
|
|
@ -27,11 +22,9 @@ async def index_data_points(data_points: list[DataPoint]):
|
||||||
|
|
||||||
index_name = f"{data_point_type.__name__}_{field_name}"
|
index_name = f"{data_point_type.__name__}_{field_name}"
|
||||||
|
|
||||||
# Add async lock to make sure two different coroutines won't create a table at the same time
|
if index_name not in created_indexes:
|
||||||
async with vector_index_lock:
|
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
|
||||||
if index_name not in created_indexes:
|
created_indexes[index_name] = True
|
||||||
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
|
|
||||||
created_indexes[index_name] = True
|
|
||||||
|
|
||||||
if index_name not in index_points:
|
if index_name not in index_points:
|
||||||
index_points[index_name] = []
|
index_points[index_name] = []
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue