From 0be2e5e57fda5e025e39b95049e8836ce408c8e2 Mon Sep 17 00:00:00 2001 From: vedanuj Date: Sun, 8 Oct 2023 12:40:15 -0700 Subject: [PATCH] guard main_grad None --- fairscale/nn/misc/flatten_params_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py index 91c4f59b9..8f0c8f341 100644 --- a/fairscale/nn/misc/flatten_params_wrapper.py +++ b/fairscale/nn/misc/flatten_params_wrapper.py @@ -370,7 +370,7 @@ def _unflatten_params_as_views(self) -> None: """ assert self.is_flattened for p in self.flat_params: - if not hasattr(p, 'main_grad') or p.main_grad.shape != p.shape: + if getattr(p, 'main_grad', None) is None or p.main_grad.shape != p.shape: p.main_grad = torch.zeros_like(p, dtype=torch.float32) ps, ps_main_grad = self.get_param_views() param_views = []