fix: make get_embeddable_data static

This commit is contained in:
Boris Arzentar 2024-12-03 21:47:23 +01:00
parent b89a4b8054
commit 0b8b270933
9 changed files with 32 additions and 22 deletions

View file

@ -87,7 +87,7 @@ class LanceDBAdapter(VectorDBInterface):
collection = await connection.open_table(collection_name)
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
)
IdType = TypeVar("IdType")

View file

@ -102,7 +102,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
)
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
)
vector_size = self.embedding_engine.get_vector_size()
@ -143,7 +143,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
id = data_point.id,
text = data_point.get_embeddable_data(),
text = DataPoint.get_embeddable_data(data_point),
) for data_point in data_points
])

View file

@ -102,7 +102,9 @@ class QDrantAdapter(VectorDBInterface):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
client = self.get_qdrant_client()
data_vectors = await self.embed_data([data_point.get_embeddable_data() for data_point in data_points])
data_vectors = await self.embed_data([
DataPoint.get_embeddable_data(data_point) for data_point in data_points
])
def convert_to_qdrant_point(data_point: DataPoint):
return models.PointStruct(

View file

@ -83,7 +83,7 @@ class WeaviateAdapter(VectorDBInterface):
from weaviate.classes.data import DataObject
data_vectors = await self.embed_data(
[data_point.get_embeddable_data() for data_point in data_points]
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
)
def convert_to_weaviate_data_points(data_point: DataPoint):
@ -116,12 +116,20 @@ class WeaviateAdapter(VectorDBInterface):
)
else:
data_point: DataObject = data_points[0]
return collection.data.update(
uuid = data_point.uuid,
vector = data_point.vector,
properties = data_point.properties,
references = data_point.references,
)
if collection.data.exists(data_point.uuid):
return collection.data.update(
uuid = data_point.uuid,
vector = data_point.vector,
properties = data_point.properties,
references = data_point.references,
)
else:
return collection.data.insert(
uuid = data_point.uuid,
vector = data_point.vector,
properties = data_point.properties,
references = data_point.references,
)
except Exception as error:
logger.error("Error creating data points: %s", str(error))
raise error
@ -133,7 +141,7 @@ class WeaviateAdapter(VectorDBInterface):
await self.create_data_points(f"{index_name}_{index_property_name}", [
IndexSchema(
id = data_point.id,
text = data_point.get_embeddable_data(),
text = DataPoint.get_embeddable_data(data_point),
) for data_point in data_points
])

View file

@ -19,10 +19,11 @@ class DataPoint(BaseModel):
# class Config:
# underscore_attrs_are_private = True
def get_embeddable_data(self):
if self._metadata and len(self._metadata["index_fields"]) > 0 \
and hasattr(self, self._metadata["index_fields"][0]):
attribute = getattr(self, self._metadata["index_fields"][0])
@classmethod
def get_embeddable_data(self, data_point):
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0 \
and hasattr(data_point, data_point._metadata["index_fields"][0]):
attribute = getattr(data_point, data_point._metadata["index_fields"][0])
if isinstance(attribute, str):
return attribute.strip()

View file

@ -20,8 +20,8 @@ async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User):
destination = get_dlt_destination()
pipeline = dlt.pipeline(
pipeline_name="file_load_from_filesystem",
destination=destination,
pipeline_name = "file_load_from_filesystem",
destination = destination,
)
@dlt.resource(standalone = True, merge_key = "id")

View file

@ -1,4 +1,3 @@
import asyncio
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.engine import DataPoint

4
poetry.lock generated
View file

@ -4659,8 +4659,8 @@ files = [
[package.dependencies]
numpy = [
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
@ -7840,4 +7840,4 @@ weaviate = ["weaviate-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9.0,<3.12"
content-hash = "b63498e7aa23cfe29d8bea1fc29b0fe4a5f1a9e8ae5ec75d45b4bd20438e26f9"
content-hash = "e2360f4be222743bb83b1e7316185c5f62bd73c0baaab3eee984e1c84f1cea65"

View file

@ -41,7 +41,7 @@ aiosqlite = "^0.20.0"
pandas = "2.0.3"
filetype = "^1.2.0"
nltk = "^3.8.1"
dlt = {extras = ["sqlalchemy"], version = "^1.3.0"}
dlt = {extras = ["sqlalchemy"], version = "^1.4.1"}
aiofiles = "^23.2.1"
qdrant-client = "^1.9.0"
graphistry = "^0.33.5"