diff --git a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
index 789ed2e84..4badb0a97 100644
--- a/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
+++ b/cognee/infrastructure/databases/vector/pgvector/PGVectorAdapter.py
@@ -124,16 +124,35 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
                 self.payload = payload
                 self.vector = vector
 
-        pgvector_data_points = [
-            PGVectorDataPoint(
-                id=data_point.id,
-                vector=data_vectors[data_index],
-                payload=serialize_data(data_point.model_dump()),
-            )
-            for (data_index, data_point) in enumerate(data_points)
-        ]
-
         async with self.get_async_session() as session:
+            # Load every existing row for the incoming ids with a single query
+            # instead of one SELECT per data point (avoids N+1 queries).
+            result = await session.execute(
+                select(PGVectorDataPoint).filter(
+                    PGVectorDataPoint.id.in_([data_point.id for data_point in data_points])
+                )
+            )
+            existing_data_points = {row.id: row for row in result.scalars().all()}
+
+            pgvector_data_points = []
+
+            for data_index, data_point in enumerate(data_points):
+                data_point_db = existing_data_points.get(data_point.id)
+
+                # If the data point already exists, update it in place;
+                # otherwise create a new row. No need to reassign the id on
+                # an existing row — it was fetched by that exact id.
+                if data_point_db is not None:
+                    data_point_db.vector = data_vectors[data_index]
+                    data_point_db.payload = serialize_data(data_point.model_dump())
+                    pgvector_data_points.append(data_point_db)
+                else:
+                    pgvector_data_points.append(
+                        PGVectorDataPoint(
+                            id=data_point.id,
+                            vector=data_vectors[data_index],
+                            payload=serialize_data(data_point.model_dump()),
+                        )
+                    )
+
             session.add_all(pgvector_data_points)
             await session.commit()
 