fix: make get_embeddable_data static
This commit is contained in:
parent
b89a4b8054
commit
0b8b270933
9 changed files with 32 additions and 22 deletions
|
|
@ -87,7 +87,7 @@ class LanceDBAdapter(VectorDBInterface):
|
|||
collection = await connection.open_table(collection_name)
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
[data_point.get_embeddable_data() for data_point in data_points]
|
||||
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
|
||||
)
|
||||
|
||||
IdType = TypeVar("IdType")
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
)
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
[data_point.get_embeddable_data() for data_point in data_points]
|
||||
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
|
||||
)
|
||||
|
||||
vector_size = self.embedding_engine.get_vector_size()
|
||||
|
|
@ -143,7 +143,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||
IndexSchema(
|
||||
id = data_point.id,
|
||||
text = data_point.get_embeddable_data(),
|
||||
text = DataPoint.get_embeddable_data(data_point),
|
||||
) for data_point in data_points
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -102,7 +102,9 @@ class QDrantAdapter(VectorDBInterface):
|
|||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
data_vectors = await self.embed_data([data_point.get_embeddable_data() for data_point in data_points])
|
||||
data_vectors = await self.embed_data([
|
||||
DataPoint.get_embeddable_data(data_point) for data_point in data_points
|
||||
])
|
||||
|
||||
def convert_to_qdrant_point(data_point: DataPoint):
|
||||
return models.PointStruct(
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
from weaviate.classes.data import DataObject
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
[data_point.get_embeddable_data() for data_point in data_points]
|
||||
[DataPoint.get_embeddable_data(data_point) for data_point in data_points]
|
||||
)
|
||||
|
||||
def convert_to_weaviate_data_points(data_point: DataPoint):
|
||||
|
|
@ -116,12 +116,20 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
)
|
||||
else:
|
||||
data_point: DataObject = data_points[0]
|
||||
return collection.data.update(
|
||||
uuid = data_point.uuid,
|
||||
vector = data_point.vector,
|
||||
properties = data_point.properties,
|
||||
references = data_point.references,
|
||||
)
|
||||
if collection.data.exists(data_point.uuid):
|
||||
return collection.data.update(
|
||||
uuid = data_point.uuid,
|
||||
vector = data_point.vector,
|
||||
properties = data_point.properties,
|
||||
references = data_point.references,
|
||||
)
|
||||
else:
|
||||
return collection.data.insert(
|
||||
uuid = data_point.uuid,
|
||||
vector = data_point.vector,
|
||||
properties = data_point.properties,
|
||||
references = data_point.references,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.error("Error creating data points: %s", str(error))
|
||||
raise error
|
||||
|
|
@ -133,7 +141,7 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
await self.create_data_points(f"{index_name}_{index_property_name}", [
|
||||
IndexSchema(
|
||||
id = data_point.id,
|
||||
text = data_point.get_embeddable_data(),
|
||||
text = DataPoint.get_embeddable_data(data_point),
|
||||
) for data_point in data_points
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -19,10 +19,11 @@ class DataPoint(BaseModel):
|
|||
# class Config:
|
||||
# underscore_attrs_are_private = True
|
||||
|
||||
def get_embeddable_data(self):
|
||||
if self._metadata and len(self._metadata["index_fields"]) > 0 \
|
||||
and hasattr(self, self._metadata["index_fields"][0]):
|
||||
attribute = getattr(self, self._metadata["index_fields"][0])
|
||||
@classmethod
|
||||
def get_embeddable_data(self, data_point):
|
||||
if data_point._metadata and len(data_point._metadata["index_fields"]) > 0 \
|
||||
and hasattr(data_point, data_point._metadata["index_fields"][0]):
|
||||
attribute = getattr(data_point, data_point._metadata["index_fields"][0])
|
||||
|
||||
if isinstance(attribute, str):
|
||||
return attribute.strip()
|
||||
|
|
|
|||
|
|
@ -20,8 +20,8 @@ async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User):
|
|||
destination = get_dlt_destination()
|
||||
|
||||
pipeline = dlt.pipeline(
|
||||
pipeline_name="file_load_from_filesystem",
|
||||
destination=destination,
|
||||
pipeline_name = "file_load_from_filesystem",
|
||||
destination = destination,
|
||||
)
|
||||
|
||||
@dlt.resource(standalone = True, merge_key = "id")
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.engine import DataPoint
|
||||
|
||||
|
|
|
|||
4
poetry.lock
generated
4
poetry.lock
generated
|
|
@ -4659,8 +4659,8 @@ files = [
|
|||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.20.3", markers = "python_version < \"3.10\""},
|
||||
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
|
||||
{version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""},
|
||||
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
|
|
@ -7840,4 +7840,4 @@ weaviate = ["weaviate-client"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9.0,<3.12"
|
||||
content-hash = "b63498e7aa23cfe29d8bea1fc29b0fe4a5f1a9e8ae5ec75d45b4bd20438e26f9"
|
||||
content-hash = "e2360f4be222743bb83b1e7316185c5f62bd73c0baaab3eee984e1c84f1cea65"
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ aiosqlite = "^0.20.0"
|
|||
pandas = "2.0.3"
|
||||
filetype = "^1.2.0"
|
||||
nltk = "^3.8.1"
|
||||
dlt = {extras = ["sqlalchemy"], version = "^1.3.0"}
|
||||
dlt = {extras = ["sqlalchemy"], version = "^1.4.1"}
|
||||
aiofiles = "^23.2.1"
|
||||
qdrant-client = "^1.9.0"
|
||||
graphistry = "^0.33.5"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue