# Post-patch form of `default_deserialize_torch_model` from thinc/shims/pytorch.py
# (commit 256e403d, "Add weights_only parameter to default_deserialize_torch_model
# for enhanced security (#950)"), with the dropped `load_state_dict` call restored.
def default_deserialize_torch_model(
    model: Any, state_bytes: bytes, device: "torch.device", weights_only: bool = True
) -> Any:
    """Deserializes the parameters of the wrapped PyTorch model and
    moves it to the specified device.

    model:
        PyTorch model whose parameters are overwritten in place with the
        deserialized state.
    state_bytes:
        Serialized parameters as a byte stream.
    device:
        PyTorch device to which the model is bound.
    weights_only:
        Whether to only load the model's weights (default: True). Setting this
        to True enhances security and avoids loading arbitrary objects.

    Returns:
        The deserialized model (the same object that was passed in).
    """
    # A freshly constructed BytesIO is already positioned at offset 0, so no
    # explicit seek is required before handing it to torch.load.
    filelike = BytesIO(state_bytes)
    state_dict = torch.load(filelike, map_location=device, weights_only=weights_only)
    # The loaded state must actually be applied to the model; without this
    # call the function would return the model with its old parameters.
    model.load_state_dict(state_dict)
    model.to(device)
    return model