Compare commits

...
Sign in to create a new pull request.

1 commit

View file

@ -176,42 +176,42 @@ class DataPoint(BaseModel):
"""
return self.model_validate_json(json_str)
# Pickle Serialization
# Byte serialization (safe patch: UTF-8 encoded JSON is used instead of pickle)
def to_pickle(self) -> bytes:
    """
    Serialize the DataPoint instance to a byte format for persistence or transmission.

    Despite the method name, this implementation does not use the pickle module. For
    security, the instance is serialized to JSON via ``to_json`` and the resulting
    string is encoded as UTF-8 bytes.

    Returns:
    --------

        - bytes: The UTF-8 encoded JSON representation of the DataPoint instance.
    """
    # JSON rather than pickle: decoding the payload never executes arbitrary code.
    return self.to_json().encode("utf-8")
@classmethod
def from_pickle(cls, pickled_data: bytes):
    """
    Deserialize a DataPoint instance from a serialized byte stream.

    Despite the method name, this implementation does not use the pickle module. For
    security, it expects UTF-8 encoded JSON data, decodes it to a string and hands it
    to ``from_json``. Byte streams that are not valid UTF-8 (for example binary pickle
    payloads) are rejected when decoding.

    Parameters:
    -----------

        - pickled_data (bytes): The bytes representation of a serialized DataPoint
          instance (UTF-8 encoded JSON) to be deserialized.

    Returns:
    --------

        A new DataPoint instance created from the serialized data.

    Raises:
    -------

        - UnicodeDecodeError: If ``pickled_data`` is not valid UTF-8.
        - Any error raised by ``from_json`` for malformed or invalid JSON.
    """
    # Do NOT use pickle.loads here: unpickling untrusted bytes can execute arbitrary code.
    json_str = pickled_data.decode("utf-8")
    return cls.from_json(json_str)
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""
@ -252,4 +252,4 @@ class DataPoint(BaseModel):
- 'DataPoint': A new DataPoint instance constructed from the provided dictionary
data.
"""
return cls.model_validate(data)
return cls.model_validate(data)