Lancedb async lock (#1222)

<!-- .github/pull_request_template.md -->

## Description
<!-- Provide a clear description of the changes in this PR -->

## DCO Affirmation
I affirm that all code in every commit of this pull request conforms to
the terms of the Topoteretes Developer Certificate of Origin.

---------

Co-authored-by: Vasilije <8619304+Vasilije1990@users.noreply.github.com>
This commit is contained in:
Igor Ilic 2025-08-12 14:46:15 +02:00 committed by GitHub
parent c33536685d
commit a75a79f012
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 55 additions and 62 deletions

View file

@@ -53,6 +53,7 @@ class SQLAlchemyAdapter:
self.engine = create_async_engine( self.engine = create_async_engine(
connection_string, connection_string,
poolclass=NullPool, poolclass=NullPool,
connect_args={"timeout": 30},
) )
else: else:
self.engine = create_async_engine( self.engine = create_async_engine(

View file

@@ -51,6 +51,7 @@ class LanceDBAdapter(VectorDBInterface):
self.url = url self.url = url
self.api_key = api_key self.api_key = api_key
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
self.VECTOR_DB_LOCK = asyncio.Lock()
async def get_connection(self): async def get_connection(self):
""" """
@@ -127,12 +128,14 @@ class LanceDBAdapter(VectorDBInterface):
payload: payload_schema payload: payload_schema
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
connection = await self.get_connection() async with self.VECTOR_DB_LOCK:
return await connection.create_table( if not await self.has_collection(collection_name):
name=collection_name, connection = await self.get_connection()
schema=LanceDataPoint, return await connection.create_table(
exist_ok=True, name=collection_name,
) schema=LanceDataPoint,
exist_ok=True,
)
async def get_collection(self, collection_name: str): async def get_collection(self, collection_name: str):
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
@@ -145,10 +148,12 @@ class LanceDBAdapter(VectorDBInterface):
payload_schema = type(data_points[0]) payload_schema = type(data_points[0])
if not await self.has_collection(collection_name): if not await self.has_collection(collection_name):
await self.create_collection( async with self.VECTOR_DB_LOCK:
collection_name, if not await self.has_collection(collection_name):
payload_schema, await self.create_collection(
) collection_name,
payload_schema,
)
collection = await self.get_collection(collection_name) collection = await self.get_collection(collection_name)
@@ -188,12 +193,13 @@ class LanceDBAdapter(VectorDBInterface):
for (data_point_index, data_point) in enumerate(data_points) for (data_point_index, data_point) in enumerate(data_points)
] ]
await ( async with self.VECTOR_DB_LOCK:
collection.merge_insert("id") await (
.when_matched_update_all() collection.merge_insert("id")
.when_not_matched_insert_all() .when_matched_update_all()
.execute(lance_data_points) .when_not_matched_insert_all()
) .execute(lance_data_points)
)
async def retrieve(self, collection_name: str, data_point_ids: list[str]): async def retrieve(self, collection_name: str, data_point_ids: list[str]):
collection = await self.get_collection(collection_name) collection = await self.get_collection(collection_name)

View file

@@ -54,6 +54,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
self.api_key = api_key self.api_key = api_key
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
self.db_uri: str = connection_string self.db_uri: str = connection_string
self.VECTOR_DB_LOCK = asyncio.Lock()
relational_db = get_relational_engine() relational_db = get_relational_engine()
@@ -124,40 +125,41 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
data_point_types = get_type_hints(DataPoint) data_point_types = get_type_hints(DataPoint)
vector_size = self.embedding_engine.get_vector_size() vector_size = self.embedding_engine.get_vector_size()
if not await self.has_collection(collection_name): async with self.VECTOR_DB_LOCK:
if not await self.has_collection(collection_name):
class PGVectorDataPoint(Base): class PGVectorDataPoint(Base):
""" """
Represent a point in a vector data space with associated data and vector representation. Represent a point in a vector data space with associated data and vector representation.
This class inherits from Base and is associated with a database table defined by This class inherits from Base and is associated with a database table defined by
__tablename__. It maintains the following public methods and instance variables: __tablename__. It maintains the following public methods and instance variables:
- __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance. - __init__(self, id, payload, vector): Initializes a new PGVectorDataPoint instance.
Instance variables: Instance variables:
- id: Identifier for the data point, defined by data_point_types. - id: Identifier for the data point, defined by data_point_types.
- payload: JSON data associated with the data point. - payload: JSON data associated with the data point.
- vector: Vector representation of the data point, with size defined by vector_size. - vector: Vector representation of the data point, with size defined by vector_size.
""" """
__tablename__ = collection_name __tablename__ = collection_name
__table_args__ = {"extend_existing": True} __table_args__ = {"extend_existing": True}
# PGVector requires one column to be the primary key # PGVector requires one column to be the primary key
id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True) id: Mapped[data_point_types["id"]] = mapped_column(primary_key=True)
payload = Column(JSON) payload = Column(JSON)
vector = Column(self.Vector(vector_size)) vector = Column(self.Vector(vector_size))
def __init__(self, id, payload, vector): def __init__(self, id, payload, vector):
self.id = id self.id = id
self.payload = payload self.payload = payload
self.vector = vector self.vector = vector
async with self.engine.begin() as connection: async with self.engine.begin() as connection:
if len(Base.metadata.tables.keys()) > 0: if len(Base.metadata.tables.keys()) > 0:
await connection.run_sync( await connection.run_sync(
Base.metadata.create_all, tables=[PGVectorDataPoint.__table__] Base.metadata.create_all, tables=[PGVectorDataPoint.__table__]
) )
@retry( @retry(
retry=retry_if_exception_type(DeadlockDetectedError), retry=retry_if_exception_type(DeadlockDetectedError),

View file

@@ -1,8 +0,0 @@
# class PineconeVectorDB(VectorDB):
# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.init_pinecone(self.index_name)
#
# def init_pinecone(self, index_name):
# # Pinecone initialization logic
# pass

View file

@@ -132,7 +132,6 @@ class QDrantAdapter(VectorDBInterface):
def __init__(self, url, api_key, embedding_engine: EmbeddingEngine, qdrant_path=None): def __init__(self, url, api_key, embedding_engine: EmbeddingEngine, qdrant_path=None):
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
if qdrant_path is not None: if qdrant_path is not None:
self.qdrant_path = qdrant_path self.qdrant_path = qdrant_path
else: else:

View file

@@ -1,5 +1,3 @@
import asyncio
from cognee.shared.logging_utils import get_logger from cognee.shared.logging_utils import get_logger
from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException from cognee.infrastructure.databases.exceptions.EmbeddingException import EmbeddingException
@@ -8,9 +6,6 @@ from cognee.infrastructure.engine import DataPoint
logger = get_logger("index_data_points") logger = get_logger("index_data_points")
# A single lock shared by all coroutines
vector_index_lock = asyncio.Lock()
async def index_data_points(data_points: list[DataPoint]): async def index_data_points(data_points: list[DataPoint]):
created_indexes = {} created_indexes = {}
@@ -27,11 +22,9 @@ async def index_data_points(data_points: list[DataPoint]):
index_name = f"{data_point_type.__name__}_{field_name}" index_name = f"{data_point_type.__name__}_{field_name}"
# Add async lock to make sure two different coroutines won't create a table at the same time if index_name not in created_indexes:
async with vector_index_lock: await vector_engine.create_vector_index(data_point_type.__name__, field_name)
if index_name not in created_indexes: created_indexes[index_name] = True
await vector_engine.create_vector_index(data_point_type.__name__, field_name)
created_indexes[index_name] = True
if index_name not in index_points: if index_name not in index_points:
index_points[index_name] = [] index_points[index_name] = []