testing
This commit is contained in:
parent
2c2f3ce453
commit
54a0791d7c
1 changed files with 359 additions and 0 deletions
|
|
@ -0,0 +1,359 @@
|
|||
import asyncio
|
||||
from os import path
|
||||
import lancedb
|
||||
from pydantic import BaseModel
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from typing import Generic, List, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from cognee.infrastructure.databases.exceptions import MissingQueryParameterError
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
from cognee.infrastructure.engine.utils import parse_id
|
||||
from cognee.infrastructure.files.storage import get_file_storage
|
||||
from cognee.modules.storage.utils import copy_model, get_own_properties
|
||||
from cognee.infrastructure.databases.vector.exceptions import CollectionNotFoundError
|
||||
|
||||
from ..embeddings.EmbeddingEngine import EmbeddingEngine
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
from ..utils import normalize_distances
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
|
||||
|
||||
class IndexSchema(DataPoint):
    """
    Minimal data point used for text indexing.

    Public attributes:

    - id: Unique string identifier of the indexed data point.
    - text: The textual content that gets embedded and indexed.
    - metadata: Default index configuration; marks 'text' as the
      embeddable field.
    """

    id: str
    text: str

    metadata: dict = {"index_fields": ["text"]}
|
||||
|
||||
|
||||
class LanceDBAdapter(VectorDBInterface):
    """
    Vector database adapter backed by LanceDB.

    Implements the VectorDBInterface contract against an async LanceDB
    connection: collection (table) management, embedding-based upserts,
    id-based retrieval, vector similarity search, and pruning.
    """

    name = "LanceDB"
    # Database location: a local filesystem path or a remote LanceDB URL.
    url: str
    # Credential for remote connections; unused for local databases.
    api_key: str
    # Shared async connection, created lazily by get_connection().
    connection: lancedb.AsyncConnection = None

    def __init__(
        self,
        url: Optional[str],
        api_key: Optional[str],
        embedding_engine: EmbeddingEngine,
    ):
        self.url = url
        self.api_key = api_key
        self.embedding_engine = embedding_engine
        # Serializes table creation and writes from this process; used with
        # double-checked existence tests below to avoid racing table creation.
        self.VECTOR_DB_LOCK = asyncio.Lock()

    async def get_connection(self):
        """
        Establishes and returns a connection to the LanceDB.

        If a connection already exists, it will return the existing connection.

        Returns:
        --------

        - lancedb.AsyncConnection: An active connection to the LanceDB.
        """
        if self.connection is None:
            self.connection = await lancedb.connect_async(self.url, api_key=self.api_key)

        return self.connection

    async def embed_data(self, data: list[str]) -> list[list[float]]:
        """
        Embeds the provided textual data into vector representation.

        Uses the embedding engine to convert the list of strings into a list of float vectors.

        Parameters:
        -----------

        - data (list[str]): A list of strings representing the data to be embedded.

        Returns:
        --------

        - list[list[float]]: A list of embedded vectors corresponding to the input data.
        """
        return await self.embedding_engine.embed_text(data)

    async def has_collection(self, collection_name: str) -> bool:
        """
        Checks if the specified collection exists in the LanceDB.

        Returns True if the collection is present, otherwise False.

        Parameters:
        -----------

        - collection_name (str): The name of the collection to check.

        Returns:
        --------

        - bool: True if the collection exists, otherwise False.
        """
        connection = await self.get_connection()
        collection_names = await connection.table_names()
        return collection_name in collection_names

    async def create_collection(self, collection_name: str, payload_schema: BaseModel):
        """
        Create the LanceDB table for `collection_name` if it does not exist.

        The table schema combines an `id`, an embedding vector sized to the
        configured embedding engine, and a payload model derived from
        `payload_schema` via get_data_point_schema (which strips metadata
        and related-model fields).

        Parameters:
        -----------

        - collection_name (str): Name of the table to create.
        - payload_schema (BaseModel): Pydantic model describing the payload.

        Returns:
        --------

        - The created table when a new table is made; None when the
          collection already existed.
        """
        vector_size = self.embedding_engine.get_vector_size()

        payload_schema = self.get_data_point_schema(payload_schema)
        data_point_types = get_type_hints(payload_schema)

        class LanceDataPoint(LanceModel):
            """
            Represents a data point in the Lance model with an ID, vector, and associated payload.

            The class inherits from LanceModel and defines the following public attributes:
            - id: A unique identifier for the data point.
            - vector: A vector representing the data point in a specified dimensional space.
            - payload: Additional data or metadata associated with the data point.
            """

            id: data_point_types["id"]
            vector: Vector(vector_size)
            payload: payload_schema

        if not await self.has_collection(collection_name):
            # Double-checked under the lock so concurrent callers in this
            # process do not race to create the same table.
            async with self.VECTOR_DB_LOCK:
                if not await self.has_collection(collection_name):
                    connection = await self.get_connection()
                    return await connection.create_table(
                        name=collection_name,
                        schema=LanceDataPoint,
                        exist_ok=True,
                    )

    async def get_collection(self, collection_name: str):
        """
        Open and return the LanceDB table for `collection_name`.

        Raises:
        -------

        - CollectionNotFoundError: If the collection does not exist.
        """
        if not await self.has_collection(collection_name):
            raise CollectionNotFoundError(f"Collection '{collection_name}' not found!")

        connection = await self.get_connection()
        return await connection.open_table(collection_name)

    async def create_data_points(self, collection_name: str, data_points: list[DataPoint]):
        """
        Embed and upsert `data_points` into `collection_name`.

        Creates the collection on demand (payload schema inferred from the
        type of the first data point), embeds each point's embeddable data,
        then merge-inserts by id: matching rows are updated, new rows are
        inserted.

        Parameters:
        -----------

        - collection_name (str): Target collection.
        - data_points (list[DataPoint]): Points to upsert. Must be non-empty,
          since the schema is taken from data_points[0].
        """
        payload_schema = type(data_points[0])

        if not await self.has_collection(collection_name):
            # Double-checked locking mirrors create_collection; note that
            # create_collection re-checks existence under the same lock.
            async with self.VECTOR_DB_LOCK:
                if not await self.has_collection(collection_name):
                    await self.create_collection(
                        collection_name,
                        payload_schema,
                    )

        collection = await self.get_collection(collection_name)

        data_vectors = await self.embed_data(
            [DataPoint.get_embeddable_data(data_point) for data_point in data_points]
        )

        IdType = TypeVar("IdType")
        PayloadSchema = TypeVar("PayloadSchema")
        vector_size = self.embedding_engine.get_vector_size()

        class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
            """
            Represents a data point in the Lance model with an ID, vector, and payload.

            This class encapsulates a data point consisting of an identifier, a vector representing
            the data, and an associated payload, allowing for operations and manipulations specific
            to the Lance data structure.
            """

            id: IdType
            vector: Vector(vector_size)
            payload: PayloadSchema

        def create_lance_data_point(data_point: DataPoint, vector: list[float]) -> LanceDataPoint:
            # Payload is built from the point's own (non-relational) properties;
            # the id is stringified for storage.
            properties = get_own_properties(data_point)
            properties["id"] = str(properties["id"])

            return LanceDataPoint[str, self.get_data_point_schema(type(data_point))](
                id=str(data_point.id),
                vector=vector,
                payload=properties,
            )

        lance_data_points = [
            create_lance_data_point(data_point, data_vectors[data_point_index])
            for (data_point_index, data_point) in enumerate(data_points)
        ]

        async with self.VECTOR_DB_LOCK:
            # Upsert: rows whose id matches are updated, the rest are inserted.
            await (
                collection.merge_insert("id")
                .when_matched_update_all()
                .when_not_matched_insert_all()
                .execute(lance_data_points)
            )

    async def retrieve(self, collection_name: str, data_point_ids: list[str]):
        """
        Fetch rows by id from `collection_name`.

        Parameters:
        -----------

        - collection_name (str): Collection to read from.
        - data_point_ids (list[str]): Ids of the rows to fetch.

        Returns:
        --------

        - list[ScoredResult]: One result per matching row; score is fixed at
          0 because no similarity ranking is performed here.

        Raises:
        -------

        - CollectionNotFoundError: If the collection does not exist.
        """
        collection = await self.get_collection(collection_name)

        if len(data_point_ids) == 1:
            # Single id is special-cased: the repr of a 1-tuple would include
            # a trailing comma ("('x',)"), which the SQL IN clause can't parse.
            results = await collection.query().where(f"id = '{data_point_ids[0]}'").to_pandas()
        else:
            results = await collection.query().where(f"id IN {tuple(data_point_ids)}").to_pandas()

        return [
            ScoredResult(
                id=parse_id(result["id"]),
                payload=result["payload"],
                score=0,
            )
            for result in results.to_dict("index").values()
        ]

    async def search(
        self,
        collection_name: str,
        query_text: str = None,
        query_vector: List[float] = None,
        limit: int = 15,
        with_vector: bool = False,
        normalized: bool = True,
    ):
        """
        Run a vector similarity search over `collection_name`.

        Exactly one of `query_text` or `query_vector` should be supplied; when
        only text is given it is embedded first. A `limit` of 0 means "all
        rows" (resolved via count_rows).

        Parameters:
        -----------

        - collection_name (str): Collection to search.
        - query_text (str): Text to embed and search with.
        - query_vector (List[float]): Pre-computed query embedding.
        - limit (int): Maximum number of results; 0 requests every row.
        - with_vector (bool): NOTE(review): currently unused in the body.
        - normalized (bool): NOTE(review): currently unused; distances are
          always passed through normalize_distances.

        Returns:
        --------

        - list[ScoredResult]: Results with distance-normalized scores; empty
          list when the collection is empty or the search yields nothing.

        Raises:
        -------

        - MissingQueryParameterError: If neither query_text nor query_vector
          is provided.
        - CollectionNotFoundError: If the collection does not exist.
        """
        if query_text is None and query_vector is None:
            raise MissingQueryParameterError()

        if query_text and not query_vector:
            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]

        collection = await self.get_collection(collection_name)

        if limit == 0:
            limit = await collection.count_rows()

        # LanceDB search will break if limit is 0 so we must return
        if limit == 0:
            return []

        results = await collection.vector_search(query_vector).limit(limit).to_pandas()

        result_values = list(results.to_dict("index").values())

        if not result_values:
            return []

        normalized_values = normalize_distances(result_values)

        return [
            ScoredResult(
                id=parse_id(result["id"]),
                payload=result["payload"],
                score=normalized_values[value_index],
            )
            for value_index, result in enumerate(result_values)
        ]

    async def batch_search(
        self,
        collection_name: str,
        query_texts: List[str],
        limit: int = None,
        with_vectors: bool = False,
    ):
        """
        Embed several query texts and run their searches concurrently.

        Parameters:
        -----------

        - collection_name (str): Collection to search.
        - query_texts (List[str]): Texts to embed and search with.
        - limit (int): Per-query result limit, forwarded to search().
        - with_vectors (bool): Forwarded to search() as with_vector.

        Returns:
        --------

        - A list of per-query result lists, in the same order as query_texts.
        """
        query_vectors = await self.embedding_engine.embed_text(query_texts)

        return await asyncio.gather(
            *[
                self.search(
                    collection_name=collection_name,
                    query_vector=query_vector,
                    limit=limit,
                    with_vector=with_vectors,
                )
                for query_vector in query_vectors
            ]
        )

    async def delete_data_points(self, collection_name: str, data_point_ids: list[str]):
        """
        Delete the rows with the given ids from `collection_name`.

        Raises:
        -------

        - CollectionNotFoundError: If the collection does not exist.
        """
        collection = await self.get_collection(collection_name)

        # Delete one at a time to avoid commit conflicts
        for data_point_id in data_point_ids:
            await collection.delete(f"id = '{data_point_id}'")

    async def create_vector_index(self, index_name: str, index_property_name: str):
        """
        Create the backing collection for a vector index.

        The collection is named "{index_name}_{index_property_name}" and uses
        IndexSchema (id + text) as its payload schema.
        """
        await self.create_collection(
            f"{index_name}_{index_property_name}", payload_schema=IndexSchema
        )

    async def index_data_points(
        self, index_name: str, index_property_name: str, data_points: list[DataPoint]
    ):
        """
        Index `data_points` into the "{index_name}_{index_property_name}" collection.

        Each point is reduced to an IndexSchema row whose text is taken from
        the point's first metadata index_field attribute.
        """
        await self.create_data_points(
            f"{index_name}_{index_property_name}",
            [
                IndexSchema(
                    id=str(data_point.id),
                    text=getattr(data_point, data_point.metadata["index_fields"][0]),
                )
                for data_point in data_points
            ],
        )

    async def prune(self):
        """
        Delete every collection and, for local databases, remove the files.

        Each table is emptied (delete all rows) and then dropped. When the
        url is an absolute local path, the database file/directory is also
        removed from storage.
        """
        connection = await self.get_connection()
        collection_names = await connection.table_names()

        for collection_name in collection_names:
            collection = await self.get_collection(collection_name)
            # Empty the table before dropping it.
            await collection.delete("id IS NOT NULL")
            await connection.drop_table(collection_name)

        # A leading "/" marks a local absolute path (vs a remote URL); only
        # then do we remove on-disk files.
        if self.url.startswith("/"):
            db_dir_path = path.dirname(self.url)
            db_file_name = path.basename(self.url)
            await get_file_storage(db_dir_path).remove_all(db_file_name)

    def get_data_point_schema(self, model_type: BaseModel):
        """
        Derive a storable payload schema from a DataPoint model type.

        Walks the model's fields and collects those that reference other
        pydantic models (directly, or inside Union/list annotations), then
        returns a copy of the model that excludes those relational fields and
        "metadata", while forcing "id" to be a plain required string.

        Parameters:
        -----------

        - model_type (BaseModel): The pydantic model class to flatten.

        Returns:
        --------

        - A new pydantic model class suitable for use as a LanceDB payload.
        """
        related_models_fields = []

        for field_name, field_config in model_type.model_fields.items():
            if hasattr(field_config, "model_fields"):
                related_models_fields.append(field_name)

            elif hasattr(field_config.annotation, "model_fields"):
                # Field's annotation is itself a pydantic model.
                related_models_fields.append(field_name)

            elif (
                get_origin(field_config.annotation) == Union
                or get_origin(field_config.annotation) is list
            ):
                # Union/Optional or list annotations: exclude the field if any
                # of the type arguments is (or wraps) a pydantic model.
                models_list = get_args(field_config.annotation)
                if any(hasattr(model, "model_fields") for model in models_list):
                    related_models_fields.append(field_name)
                elif models_list and any(get_args(model) is DataPoint for model in models_list):
                    # NOTE(review): get_args() returns a tuple, so "is DataPoint"
                    # appears to never be true — confirm intent.
                    related_models_fields.append(field_name)
                elif models_list and any(
                    submodel is DataPoint for submodel in get_args(models_list[0])
                ):
                    related_models_fields.append(field_name)

            elif get_origin(field_config.annotation) == Optional:
                # NOTE(review): typing normalizes Optional[X] to Union[X, None],
                # so get_origin() never returns Optional — this branch appears
                # unreachable; Optional fields are handled by the Union branch.
                model = get_args(field_config.annotation)
                if hasattr(model, "model_fields"):
                    related_models_fields.append(field_name)

        return copy_model(
            model_type,
            include_fields={
                "id": (str, ...),
            },
            exclude_fields=["metadata"] + related_models_fields,
        )
|
||||
Loading…
Add table
Reference in a new issue