cognee/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py
Igor Ilic 343ac47fd4 fix: Update import location for LanceDB
Updated import path for LanceDB exceptions

Fix COG-502
2024-12-02 13:19:55 +01:00

254 lines
No EOL
9.1 KiB
Python

import inspect
from typing import List, Optional, get_type_hints, Generic, TypeVar
import asyncio
from uuid import UUID
import lancedb
from pydantic import BaseModel
from lancedb.pydantic import Vector, LanceModel
from cognee.exceptions import InvalidValueError
from cognee.infrastructure.engine import DataPoint
from cognee.infrastructure.files.storage import LocalStorage
from cognee.modules.storage.utils import copy_model, get_own_properties
from ..models.ScoredResult import ScoredResult
from ..vector_db_interface import VectorDBInterface
from ..utils import normalize_distances
from ..embeddings.EmbeddingEngine import EmbeddingEngine
class IndexSchema(DataPoint):
    """Payload model for the per-property vector index collections.

    Each row holds a data point's id (as a string) and the text of one of
    that data point's indexed fields.
    """

    id: str
    text: str

    # Declares `text` as the field whose value is embedded for this model.
    _metadata: dict = dict(index_fields = ["text"])
class LanceDBAdapter(VectorDBInterface):
    """Vector database adapter backed by LanceDB's async API.

    Each collection is a LanceDB table whose rows carry an ``id``, an
    embedding ``vector`` and a ``payload`` with the data point's own
    properties.
    """

    name = "LanceDB"
    url: str
    api_key: str
    # Opened lazily by get_connection() and reused afterwards.
    connection: lancedb.AsyncConnection = None

    def __init__(
        self,
        url: Optional[str],
        api_key: Optional[str],
        embedding_engine: EmbeddingEngine,
    ):
        self.url = url
        self.api_key = api_key
        self.embedding_engine = embedding_engine

    async def get_connection(self):
        """Return the cached async connection, opening it on first use."""
        if self.connection is None:
            self.connection = await lancedb.connect_async(self.url, api_key = self.api_key)

        return self.connection

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        """Embed a batch of texts with the configured embedding engine."""
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        """Return True if a table named ``collection_name`` exists."""
        connection = await self.get_connection()
        collection_names = await connection.table_names()
        return collection_name in collection_names

    async def create_collection(self, collection_name: str, payload_schema: BaseModel):
        """Create the table for ``collection_name`` if it does not exist yet.

        Returns the newly created table, or None if the collection already
        existed.
        """
        vector_size = self.embedding_engine.get_vector_size()

        payload_schema = self.get_data_point_schema(payload_schema)
        data_point_types = get_type_hints(payload_schema)

        class LanceDataPoint(LanceModel):
            id: data_point_types["id"]
            vector: Vector(vector_size)
            payload: payload_schema

        if not await self.has_collection(collection_name):
            connection = await self.get_connection()
            return await connection.create_table(
                name = collection_name,
                schema = LanceDataPoint,
                exist_ok = True,
            )

    async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
        """Embed ``data_points`` and upsert them into ``collection_name``.

        If the collection does not exist, it is created from the first data
        point's type. Existing rows whose ids match incoming data points are
        replaced. Rows with other ids are left untouched.

        Raises:
            InvalidValueError: if the table did not grow after the insert.
        """
        if not data_points:
            # Nothing to embed or insert (and data_points[0] would fail below).
            return

        connection = await self.get_connection()

        payload_schema = type(data_points[0])
        payload_schema = self.get_data_point_schema(payload_schema)

        if not await self.has_collection(collection_name):
            await self.create_collection(
                collection_name,
                payload_schema,
            )

        collection = await connection.open_table(collection_name)

        data_vectors = await self.embed_data(
            [data_point.get_embeddable_data() for data_point in data_points]
        )

        IdType = TypeVar("IdType")
        PayloadSchema = TypeVar("PayloadSchema")
        vector_size = self.embedding_engine.get_vector_size()

        class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
            id: IdType
            vector: Vector(vector_size)
            payload: PayloadSchema

        # copy_model builds a fresh class on every call. Build one parametrized
        # row model per data point type instead of one per data point.
        row_models = {}

        def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
            data_point_type = type(data_point)
            if data_point_type not in row_models:
                row_models[data_point_type] = LanceDataPoint[
                    str, self.get_data_point_schema(data_point_type)
                ]

            properties = get_own_properties(data_point)
            properties["id"] = str(properties["id"])

            return row_models[data_point_type](
                id = str(data_point.id),
                vector = vector,
                payload = properties,
            )

        lance_data_points = [
            create_lance_data_point(data_point, data_vectors[data_point_index])
            for (data_point_index, data_point) in enumerate(data_points)
        ]

        # TODO: replace delete + add with a native upsert (e.g. LanceDB
        # merge_insert on "id"). Until then, remove only the rows being
        # re-inserted, so rows written by earlier calls are not wiped.
        await collection.delete(
            self._build_id_filter(data_point.id for data_point in data_points)
        )

        original_size = await collection.count_rows()
        await collection.add(lance_data_points)
        new_size = await collection.count_rows()

        if new_size <= original_size:
            raise InvalidValueError(message=
                "LanceDB create_datapoints error: data points did not get added.")

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
        """Fetch rows by id as ScoredResults (every score is 0).

        Returns an empty list when ``data_point_ids`` is empty.
        """
        if not data_point_ids:
            return []

        connection = await self.get_connection()
        collection = await connection.open_table(collection_name)

        results = await collection.query().where(
            self._build_id_filter(data_point_ids)
        ).to_pandas()
        result_values = list(results.to_dict("index").values())

        return self._to_scored_results(result_values, [0] * len(result_values))

    async def get_distance_from_collection_elements(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: List[float] = None
    ):
        """Run a vector search with no explicit limit and return normalized scores.

        Raises:
            InvalidValueError: if neither query_text nor query_vector is given.
        """
        query_vector = await self._resolve_query_vector(query_text, query_vector)

        connection = await self.get_connection()
        collection = await connection.open_table(collection_name)

        results = await collection.vector_search(query_vector).to_pandas()

        result_values = list(results.to_dict("index").values())
        normalized_values = normalize_distances(result_values)

        return self._to_scored_results(result_values, normalized_values)

    async def search(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: List[float] = None,
        limit: Optional[int] = 5,
        with_vector: bool = False,
        normalized: bool = True
    ):
        """Return the nearest rows to the query as ScoredResults.

        ``limit`` caps the result count. ``None`` applies no explicit limit.
        ``with_vector`` and ``normalized`` are accepted for interface
        compatibility but are not used here (scores are always normalized).

        Raises:
            InvalidValueError: if neither query_text nor query_vector is given.
        """
        query_vector = await self._resolve_query_vector(query_text, query_vector)

        connection = await self.get_connection()
        collection = await connection.open_table(collection_name)

        query = collection.vector_search(query_vector)
        if limit is not None:
            # batch_search forwards limit=None; only constrain when given.
            query = query.limit(limit)
        results = await query.to_pandas()

        result_values = list(results.to_dict("index").values())
        normalized_values = normalize_distances(result_values)

        return self._to_scored_results(result_values, normalized_values)

    async def batch_search(
        self,
        collection_name: str,
        query_texts: List[str],
        limit: Optional[int] = None,
        with_vectors: bool = False,
    ):
        """Embed ``query_texts`` in one call and run the searches concurrently."""
        query_vectors = await self.embedding_engine.embed_text(query_texts)

        return await asyncio.gather(
            *[self.search(
                collection_name = collection_name,
                query_vector = query_vector,
                limit = limit,
                with_vector = with_vectors,
            ) for query_vector in query_vectors]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        """Delete rows whose id is in ``data_point_ids``.

        Returns whatever LanceDB's delete returns, or None when there is
        nothing to delete.
        """
        if not data_point_ids:
            return None

        connection = await self.get_connection()
        collection = await connection.open_table(collection_name)
        return await collection.delete(self._build_id_filter(data_point_ids))

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """Create the ``{index_name}_{index_property_name}`` collection for IndexSchema rows."""
        await self.create_collection(f"{index_name}_{index_property_name}", payload_schema = IndexSchema)

    async def index_data_points(self, index_name: str, index_property_name: str, data_points: list[DataPoint]):
        """Store each data point's first index field text in the matching index collection."""
        await self.create_data_points(f"{index_name}_{index_property_name}", [
            IndexSchema(
                id = str(data_point.id),
                text = getattr(data_point, data_point._metadata["index_fields"][0]),
            ) for data_point in data_points
        ])

    async def prune(self):
        """Remove the local database directory when ``url`` is an absolute path."""
        # url is Optional; only a local absolute path is treated as removable.
        if self.url and self.url.startswith("/"):
            LocalStorage.remove_all(self.url)  # Remove the temporary directory and files inside

    def get_data_point_schema(self, model_type):
        """Copy ``model_type`` with a string ``id`` field and without ``_metadata``."""
        return copy_model(
            model_type,
            include_fields = {
                "id": (str, ...),
            },
            exclude_fields = ["_metadata"],
        )

    async def _resolve_query_vector(self, query_text, query_vector):
        """Return ``query_vector``, embedding ``query_text`` when no vector is given.

        Raises:
            InvalidValueError: if both arguments are None.
        """
        if query_text is None and query_vector is None:
            raise InvalidValueError(message="One of query_text or query_vector must be provided!")

        # `is None` rather than truthiness: array-like vectors have no
        # unambiguous truth value.
        if query_vector is None:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        return query_vector

    @staticmethod
    def _to_scored_results(rows, scores):
        """Pair each result row (a dict with "id" and "payload") with its score."""
        return [ScoredResult(
            id = UUID(row["id"]),
            payload = row["payload"],
            score = scores[row_index],
        ) for row_index, row in enumerate(rows)]

    @staticmethod
    def _build_id_filter(ids) -> str:
        """Build an ``id IN (...)`` filter expression from ``ids``.

        Each id is stringified (so UUID objects work). Embedded single quotes
        are doubled, so every id stays a string literal.
        """
        quoted_ids = ", ".join(
            "'" + str(data_point_id).replace("'", "''") + "'" for data_point_id in ids
        )
        return f"id IN ({quoted_ids})"