diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index be874e125..bf3e6c018 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1386,8 +1386,9 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
 
         # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
         # the conversion).
+        is_bf16 = self.compute_dtype == torch.bfloat16
         if self._is_root and self.mixed_precision:
-            args, kwargs = cast_floats_to_right_precision(True, True, *args, **kwargs)
+            args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)
 
         if self not in self._fsdp_forward_ordering:
             self._my_fsdp_instance_idx = len(self._fsdp_forward_ordering)
@@ -1397,7 +1398,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         # no_grad is not used because the input might be for a non-root instance,
         # which mean autograd needs to go through the conversion.
         if self.force_input_to_fp32 and not self.mixed_precision:
-            args, kwargs = cast_floats_to_right_precision(False, False, *args, **kwargs)
+            args, kwargs = cast_floats_to_right_precision(False, False, is_bf16, *args, **kwargs)
 
         # All-gather full parameters. This will also transfer FP32 parameters to
         # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
@@ -2054,6 +2055,7 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
         if params is None:
             params = self.params
         self.has_full_params = False
+        current_stream = torch.cuda.current_stream()
 
         for p in params:
             if not p._is_sharded:  # e.g., world_size == 1
@@ -2182,7 +2184,6 @@ def consolidate_shard_weights(
                 for n, t, s in zip(names, full_param.split(numels), shapes):
                     out_state_dict_key = ".".join([fsdp_path, n]) if fsdp_path else n
                     consolidated_weights[out_state_dict_key] = t.view(s)
-
             # copy shared parameters
             for src_path, dest_path in metadata["shared_param_info"]:
                 consolidated_weights[dest_path] = consolidated_weights[src_path]
@@ -2462,7 +2463,7 @@ def _get_default_cuda_device(module: nn.Module) -> torch.device:
     return torch.device("cuda")
 
 
-def cast_floats_to_right_precision(to_fp16: bool, no_grad: bool, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
+def cast_floats_to_right_precision(to_fp16: bool, no_grad: bool, is_bf16: bool, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
     """
     Cast floating point Tensors in *args or **kwargs to FP16 or FP32 if they are not.
     We also retain the requires_grad flag so that casting doesn't affect the autograd graph.
@@ -2470,7 +2471,10 @@ def cast_floats_to_right_precision(to_fp16: bool, no_grad: bool, *args: Any, **k
 
     def fn_fp16(x: torch.Tensor) -> torch.Tensor:
         if x.dtype is torch.float32:
-            y = x.half()
+            if is_bf16:
+                y = x.bfloat16()
+            else:
+                y = x.half()
             if x.is_leaf:
                 y.requires_grad = x.requires_grad
             return y
diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py
index 0cbdce2bb..38265dd2b 100644
--- a/fairscale/nn/misc/flatten_params_wrapper.py
+++ b/fairscale/nn/misc/flatten_params_wrapper.py
@@ -372,6 +372,7 @@ def _unflatten_params_as_views(self) -> None:
         ps = self.get_param_views()
         param_views = []
         for (_, m, n), p in zip(self._param_infos, ps):
+            setattr(p, '_fsdp_weight', True)
             setattr(m, n, p)  # This will set as plain attr
             param_views.append(p)
 
@@ -382,6 +383,7 @@ def _unflatten_params_as_views(self) -> None:
         for (_, _, m, n, shared_m, shared_n) in self._shared_param_infos:
             setattr(m, n, getattr(shared_m, shared_n))
+
 
     @contextmanager
     def unflatten_params(self, flat_params: Optional[List[Tensor]] = None) -> Generator:
         """
diff --git a/tests/nn/data_parallel/test_fsdp.py b/tests/nn/data_parallel/test_fsdp.py
index 7313bf262..06f96a9db 100644
--- a/tests/nn/data_parallel/test_fsdp.py
+++ b/tests/nn/data_parallel/test_fsdp.py
@@ -211,6 +211,17 @@ def test_mixed_precision_autocast_fp32_compute(self):
             expected_buffer_type=torch.float32,
         )
 
+    def test_mixed_precision_bfloat16(self):
+        self._spawn_test_case(
+            {"mixed_precision": True, "compute_dtype": torch.bfloat16},
+            True,  # autocast enabled
+            torch.bfloat16,  # expected_input_dtype
+            torch.bfloat16,  # expected_param_dtype
+            torch.float32,  # expected_loss_dtype
+            torch.bfloat16,  # expected_reduce_dtype
+            expected_buffer_type=torch.float32,
+        )
+
     def test_fp32_reduce_scatter(self):
         self._spawn_test_case(
             {"mixed_precision": True, "fp32_reduce_scatter": True},
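For reviewers, a minimal usage sketch of what this patch enables: wrapping a module with `mixed_precision=True` and `compute_dtype=torch.bfloat16`, the same configuration exercised by `test_mixed_precision_bfloat16` above. This is not code from the patch; the single-process process group, the toy `nn.Linear` model, and the SGD optimizer are illustrative assumptions.

```python
# Illustrative only -- not part of the diff above. Assumes one CUDA GPU with
# bfloat16 support and fairscale installed.
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# FSDP requires an initialized process group; use a 1-process group for the sketch.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

model = nn.Linear(8, 8).cuda()
# mixed_precision=True with compute_dtype=torch.bfloat16 mirrors the config in
# test_mixed_precision_bfloat16: inputs and gathered params are cast to bfloat16
# for compute, while the sharded FP32 weights are kept for the optimizer step.
fsdp_model = FSDP(model, mixed_precision=True, compute_dtype=torch.bfloat16)
optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.01)

x = torch.randn(4, 8, device="cuda")  # FP32 input; the root FSDP casts it to bfloat16
loss = fsdp_model(x).sum()
loss.backward()
optim.step()

dist.destroy_process_group()
```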