diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 5a1d7be35..37d340004 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -87,7 +87,7 @@ class LanceDBAdapter(VectorDBInterface): collection = await connection.open_table(collection_name) data_vectors = await self.embed_data( - [data_point.get_embeddable_data() for data_point in data_points] + [DataPoint.get_embeddable_data(data_point) for data_point in data_points] ) IdType = TypeVar("IdType") diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py index 27db2c276..f2e5ee369 100644 --- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py +++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py @@ -102,7 +102,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): ) data_vectors = await self.embed_data( - [data_point.get_embeddable_data() for data_point in data_points] + [DataPoint.get_embeddable_data(data_point) for data_point in data_points] ) vector_size = self.embedding_engine.get_vector_size() @@ -143,7 +143,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface): await self.create_data_points(f"{index_name}_{index_property_name}", [ IndexSchema( id = data_point.id, - text = data_point.get_embeddable_data(), + text = DataPoint.get_embeddable_data(data_point), ) for data_point in data_points ]) diff --git a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py index dc33e98ae..d5d2a1a5c 100644 --- a/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py +++ b/cognee/infrastructure/databases/vector/qdrant/QDrantAdapter.py @@ -102,7 +102,9 @@ class QDrantAdapter(VectorDBInterface): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): client = self.get_qdrant_client() - data_vectors = await self.embed_data([data_point.get_embeddable_data() for data_point in data_points]) + data_vectors = await self.embed_data([ + DataPoint.get_embeddable_data(data_point) for data_point in data_points + ]) def convert_to_qdrant_point(data_point: DataPoint): return models.PointStruct( diff --git a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py index 0c97dc9a8..c16f765b0 100644 --- a/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py +++ b/cognee/infrastructure/databases/vector/weaviate_db/WeaviateAdapter.py @@ -83,7 +83,7 @@ class WeaviateAdapter(VectorDBInterface): from weaviate.classes.data import DataObject data_vectors = await self.embed_data( - [data_point.get_embeddable_data() for data_point in data_points] + [DataPoint.get_embeddable_data(data_point) for data_point in data_points] ) def convert_to_weaviate_data_points(data_point: DataPoint): @@ -116,12 +116,20 @@ class WeaviateAdapter(VectorDBInterface): ) else: data_point: DataObject = data_points[0] - return collection.data.update( - uuid = data_point.uuid, - vector = data_point.vector, - properties = data_point.properties, - references = data_point.references, - ) + if collection.data.exists(data_point.uuid): + return collection.data.update( + uuid = data_point.uuid, + vector = data_point.vector, + properties = data_point.properties, + references = data_point.references, + ) + else: + return collection.data.insert( + uuid = data_point.uuid, + vector = data_point.vector, + properties = data_point.properties, + references = data_point.references, + ) except Exception as error: logger.error("Error creating data points: %s", str(error)) raise error @@ -133,7 +141,7 @@ class WeaviateAdapter(VectorDBInterface): await self.create_data_points(f"{index_name}_{index_property_name}", [ IndexSchema( id = data_point.id, - text = data_point.get_embeddable_data(), + text = DataPoint.get_embeddable_data(data_point), ) for data_point in data_points ]) diff --git a/cognee/infrastructure/engine/models/DataPoint.py b/cognee/infrastructure/engine/models/DataPoint.py index b76971f34..abb924f2f 100644 --- a/cognee/infrastructure/engine/models/DataPoint.py +++ b/cognee/infrastructure/engine/models/DataPoint.py @@ -19,10 +19,11 @@ class DataPoint(BaseModel): # class Config: # underscore_attrs_are_private = True - def get_embeddable_data(self): - if self._metadata and len(self._metadata["index_fields"]) > 0 \ - and hasattr(self, self._metadata["index_fields"][0]): - attribute = getattr(self, self._metadata["index_fields"][0]) + @classmethod + def get_embeddable_data(self, data_point): + if data_point._metadata and len(data_point._metadata["index_fields"]) > 0 \ + and hasattr(data_point, data_point._metadata["index_fields"][0]): + attribute = getattr(data_point, data_point._metadata["index_fields"][0]) if isinstance(attribute, str): return attribute.strip() diff --git a/cognee/tasks/ingestion/ingest_data_with_metadata.py b/cognee/tasks/ingestion/ingest_data_with_metadata.py index 0c17b71f5..abd3c9f94 100644 --- a/cognee/tasks/ingestion/ingest_data_with_metadata.py +++ b/cognee/tasks/ingestion/ingest_data_with_metadata.py @@ -20,8 +20,8 @@ async def ingest_data_with_metadata(data: Any, dataset_name: str, user: User): destination = get_dlt_destination() pipeline = dlt.pipeline( - pipeline_name="file_load_from_filesystem", - destination=destination, + pipeline_name = "file_load_from_filesystem", + destination = destination, ) @dlt.resource(standalone = True, merge_key = "id") diff --git a/cognee/tasks/storage/index_data_points.py b/cognee/tasks/storage/index_data_points.py index 58e4f096d..786168b58 100644 --- a/cognee/tasks/storage/index_data_points.py +++ b/cognee/tasks/storage/index_data_points.py @@ -1,4 +1,3 @@ -import asyncio from cognee.infrastructure.databases.vector import get_vector_engine from cognee.infrastructure.engine import DataPoint diff --git a/poetry.lock b/poetry.lock index ed9d932b0..9b309de51 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4659,8 +4659,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -7840,4 +7840,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.12" -content-hash = "b63498e7aa23cfe29d8bea1fc29b0fe4a5f1a9e8ae5ec75d45b4bd20438e26f9" +content-hash = "e2360f4be222743bb83b1e7316185c5f62bd73c0baaab3eee984e1c84f1cea65" diff --git a/pyproject.toml b/pyproject.toml index 23aba656b..46d0a89a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ aiosqlite = "^0.20.0" pandas = "2.0.3" filetype = "^1.2.0" nltk = "^3.8.1" -dlt = {extras = ["sqlalchemy"], version = "^1.3.0"} +dlt = {extras = ["sqlalchemy"], version = "^1.4.1"} aiofiles = "^23.2.1" qdrant-client = "^1.9.0" graphistry = "^0.33.5"