From 022479730cb3849faec076ee4e8b47dccec5a677 Mon Sep 17 00:00:00 2001
From: Jongsoo Park
Date: Sun, 17 Sep 2023 11:31:24 -0700
Subject: [PATCH 1/3] fp8 allgather
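
Cast fp32 parameter shards to fp8 with transformer_engine's cast_to_fp8
(using each TE module's fp8_meta scales) before the all-gather, so full
weights cross the wire as fp8 instead of 16-bit values, roughly halving
all-gather traffic relative to bf16. Non-flattened params (e.g. norm
weights) are assumed precision-critical and kept in bf16.

The core idea, as a rough standalone sketch (plain PyTorch with
hypothetical names and a naive per-tensor scale; it assumes an already
initialized process group, whereas the real path below goes through
transformer_engine and its amax-history-based scaling):

    import torch
    import torch.distributed as dist

    def fp8_allgather_sketch(shard_fp32: torch.Tensor, world_size: int) -> torch.Tensor:
        # Pick a per-tensor scale so the shard fits the e4m3 range (max ~448).
        scale = 448.0 / shard_fp32.abs().max().clamp(min=1e-12)
        shard_fp8 = (shard_fp32 * scale).to(torch.float8_e4m3fn)
        full = torch.empty(
            shard_fp8.numel() * world_size,
            dtype=torch.float8_e4m3fn,
            device=shard_fp8.device,
        )
        # All-gather the raw fp8 bytes; no reduction math is needed, so a
        # uint8 view of the same storage is enough for the collective.
        dist.all_gather_into_tensor(full.view(torch.uint8), shard_fp8.view(torch.uint8))
        return full  # consumed by fp8 GEMMs together with 1 / scale
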
---
 .../fully_sharded_data_parallel.py | 118 +++++++++++++++---
 1 file changed, 103 insertions(+), 15 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 0cff54acb..6d7c8db26 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -39,8 +39,10 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
+import transformer_engine.pytorch as te

 from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.misc.flatten_params_wrapper import FlatParameter
 from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import (
@@ -53,6 +55,11 @@ from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
+from transformer_engine.pytorch.cpp_extensions import (
+    cast_to_fp8,
+    DType,
+    FP8FwdTensors,
+)

 from . import fsdp_optim_utils as ou
@@ -455,9 +462,20 @@ def __init__(
         non_flatten_params = params
         param_name_groups = [[n] for n in param_names]
         if self.flatten_parameters:
-            to_be_flatten_params = [params]
-            non_flatten_params = []
-            param_name_groups = [param_names]
+            to_be_flatten_params = [
+                [
+                    params[i]
+                    for i in range(len(params))
+                    if "norm_weight" not in param_names[i]
+                ]
+            ]
+            non_flatten_params = [
+                params[i] for i in range(len(params)) if "norm_weight" in param_names[i]
+            ]
+            param_name_groups = [
+                [n for n in param_names if "norm_weight" not in n],
+                [n for n in param_names if "norm_weight" in n],
+            ]
         del param_names

         self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
@@ -1261,7 +1279,17 @@ def _init_param_attributes(self, p: Parameter) -> None:
         # storage to size 0 at init (here) and re-materialize (by copying
         # from _fp32_shard) as needed. If offloading params to CPU, the
         # dtype of the fp16 shard will depend on the *`compute_dtype`*.
-        p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
+        if self.compute_dtype in [
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
+        ] and not isinstance(p, FlatParameter):
+            # assume non-flattened params are precision-critical, like norms
+            dtype = torch.bfloat16
+        else:
+            dtype = self.compute_dtype
+        p._fp16_shard = torch.zeros_like(
+            p._fp32_shard, device=self.compute_device, dtype=dtype
+        )
         free_storage_(p._fp16_shard)

         if self.mixed_precision:
@@ -1279,8 +1307,18 @@ def _init_param_attributes(self, p: Parameter) -> None:
         # world_size, although these padding elements will be removed before the
         # relevant computation.
         if p._is_sharded:
+            if self.compute_dtype in [
+                torch.float8_e5m2,
+                torch.float8_e4m3fn,
+            ] and not isinstance(p, FlatParameter):
+                # assume non-flattened params are precision-critical, like norms
+                dtype = torch.bfloat16
+            else:
+                dtype = self.compute_dtype
             p._full_param_padded = torch.zeros(
-                p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
+                p.data.numel() * self.world_size,
+                device=self.compute_device,
+                dtype=dtype,
             )
             free_storage_(p._full_param_padded)
@@ -1393,7 +1431,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
         # the conversion).
-        is_bf16 = self.compute_dtype == torch.bfloat16
+        is_bf16 = self.compute_dtype in [
+            torch.bfloat16,
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
+        ]
         if self._is_root and self.mixed_precision:
             args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)
@@ -1691,9 +1733,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         with torch.cuda.stream(self._streams["post_backward"]):
             orig_grad_data = param.grad.data

-            if self.mixed_precision and self.fp32_reduce_scatter:
-                # Cast grad to FP32.
-                param.grad.data = param.grad.data.to(param.dtype)
+            if self.mixed_precision:
+                if self.fp32_reduce_scatter:
+                    # Cast grad to FP32.
+                    param.grad.data = param.grad.data.to(param.dtype)
+                elif self.compute_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
+                    # Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
+                    param.grad.data = param.grad.data.to(torch.bfloat16)

             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
@@ -2382,12 +2428,54 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
             for p in params:
                 assert p._fp16_shard is not None
                 alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
-                p._fp16_shard.copy_(
-                    # If move_params_to_cpu is True, this will be non-blocking
-                    # because _fp32_shard is pinned, otherwise it's a no-op.
-                    p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
-                )
-                p.data = p._fp16_shard
+                if self.compute_dtype in [
+                    torch.float8_e5m2,
+                    torch.float8_e4m3fn,
+                ] and p._fp16_shard.dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
+                    assert isinstance(p, FlatParameter)
+                    assert len(p._param_infos) == len(p._param_numels)
+                    numel_per_shard = p.numel()
+                    offset = -numel_per_shard * self.rank
+                    for i in range(len(p._param_infos)):
+                        _, m, n = p._param_infos[i]
+                        numel = p._param_numels[i]
+                        if offset + numel <= 0 or offset >= numel_per_shard:
+                            offset += numel
+                            continue
+                        assert isinstance(
+                            m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP)
+                        )
+                        fp8_dtype_forward = te.fp8.get_fp8_te_dtype(
+                            m.fp8_meta["recipe"], fprop_tensor=True
+                        )
+                        if not m.fp8_initialized:
+                            m.fp8_init(
+                                num_gemms=2 if isinstance(m, te.LayerNormMLP) else 1
+                            )
+                        begin = max(offset, 0)
+                        end = min(offset + numel, numel_per_shard)
+                        cast_to_fp8(
+                            p._fp32_shard[begin:end],
+                            m.fp8_meta["scaling_fwd"],
+                            FP8FwdTensors.GEMM2_WEIGHT
+                            if n == "fc2_weight"
+                            else FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                            out=p._fp16_shard[begin:end],
+                        )
+                        offset += numel
+                    p.data = p._fp16_shard.view(
+                        torch.float8_e4m3fn
+                        if fp8_dtype_forward == DType.kFloat8E4M3
+                        else torch.float8_e5m2
+                    )
+                else:
+                    p._fp16_shard.copy_(
+                        # If move_params_to_cpu is True, this will be non-blocking
+                        # because _fp32_shard is pinned, otherwise it's a no-op.
+                        p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
+                    )
+                    p.data = p._fp16_shard
         torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

     @torch.no_grad()

From db6a1c7a255d56b88e3416c1203f6ec4e4d489b8 Mon Sep 17 00:00:00 2001
From: Jongsoo Park
Date: Thu, 21 Sep 2023 18:25:38 -0700
Subject: [PATCH 2/3] don't shard norm weights
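
Restrict sharding to FlatParameters: the unflattened norm weights are now
replicated (each rank keeps a full copy) instead of flat-sharded. That
means summon_full_params must not write shards back for them, and the
world_size == 1 assertion in the unsharded-gradient path of
_post_backward_hook no longer holds.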
---
 fairscale/nn/data_parallel/fully_sharded_data_parallel.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 6d7c8db26..f1aa72842 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -775,7 +775,7 @@ def _shard_parameters_(self) -> None:
             assert p.dtype == torch.float32

             # If world_size is 1, then we all-reduce grads instead of sharding.
-            p._is_sharded = self.world_size > 1
+            p._is_sharded = self.world_size > 1 and isinstance(p, FlatParameter)
             p._orig_size = p.data.size()

             if not p._is_sharded:
@@ -1159,7 +1159,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
                 non_shared_params
             ), f"{len(full_tensors)} vs. {len(non_shared_params)}"
             for p, (full_tensor, safe_to_free) in zip(non_shared_params, full_tensors):
-                if not volatile:
+                if not volatile and p._is_sharded:
                     # Copy any changes made to the full params back into
                     # the corresponding local shards.
                     local_shard, _ = self._get_shard(full_tensor)
@@ -1773,7 +1773,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # Currently the only way for _is_sharded to be False is if
                 # world_size == 1. This could be relaxed in the future, in which
                 # case grads should be all-reduced here.
-                assert self.world_size == 1
+                # assert self.world_size == 1
                 self._post_reduction_hook(param, param.grad.data)

             # After _post_backward_hook returns, orig_grad_data will eventually

From c0f4b97b8df697134d1c53730957868e84d89d0b Mon Sep 17 00:00:00 2001
From: Jongsoo Park
Date: Wed, 4 Oct 2023 16:03:01 -0700
Subject: [PATCH 3/3] use main_grad for higher precision gradient
 accumulation; update amax during post_backward_hook
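
Two changes for fp8 training with TE modules:

1. For FlatParameters whose sub-params all belong to TE modules, allocate
   a single fp32 main_grad buffer and hand every TE weight a view into
   it, so weight gradients accumulate in fp32 across micro-batches; the
   post-backward hook then reduces main_grad instead of the low-precision
   param.grad.

2. Run amax_and_scale_update in forward, just before
   _cast_fp32_param_shards_to_fp16 quantizes the shards for the
   all-gather (including the instance whose all-gather we prefetch), so
   quantization always uses up-to-date scales.

The main_grad wiring in miniature (a hypothetical standalone sketch of
what _prep_grads_for_backward sets up below; zero-initialized here for
clarity, while the patch allocates with empty_like):

    flat = flat_param  # a FlatParameter backed entirely by TE module weights
    flat.main_grad = torch.zeros_like(flat, dtype=torch.float32)
    views = flat.get_param_views(flat.main_grad)
    for (_, module, name), view in zip(flat._param_infos, views):
        # TE's fused wgrad accumulation adds this weight's grad into the view.
        getattr(module, name).main_grad = view
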
---
 .../fully_sharded_data_parallel.py | 139 +++++++++++++-----
 1 file changed, 104 insertions(+), 35 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index f1aa72842..35d691343 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -55,11 +55,8 @@ from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
-from transformer_engine.pytorch.cpp_extensions import (
-    cast_to_fp8,
-    DType,
-    FP8FwdTensors,
-)
+from transformer_engine.pytorch.cpp_extensions import cast_to_fp8, DType, FP8FwdTensors
+from transformer_engine.pytorch.fp8 import amax_and_scale_update, FP8GlobalStateManager

 from . import fsdp_optim_utils as ou
@@ -123,6 +120,14 @@ class OffloadConfig:
     dir: Optional[str] = None


+def _is_te_module_with_weights(m: nn.Module) -> bool:
+    return isinstance(m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP))
+
+
+def _is_fp8_dtype(dtype: torch.dtype) -> bool:
+    return dtype in [torch.float8_e5m2, torch.float8_e4m3fn]
+
+
 class FullyShardedDataParallel(nn.Module):
     """
     A wrapper for sharding Module parameters across data parallel workers. This
@@ -462,6 +467,8 @@ def __init__(
         non_flatten_params = params
         param_name_groups = [[n] for n in param_names]
         if self.flatten_parameters:
+            # don't flatten norm_weights since we need to handle them
+            # separately during fp8 training
             to_be_flatten_params = [
                 [
                     params[i]
                     for i in range(len(params))
                     if "norm_weight" not in param_names[i]
                 ]
             ]
             non_flatten_params = [
-                params[i] for i in range(len(params)) if "norm_weight" in param_names[i]
+                params[i]
+                for i in range(len(params))
+                if "norm_weight" in param_names[i]
             ]
             param_name_groups = [
                 [n for n in param_names if "norm_weight" not in n],
                 [n for n in param_names if "norm_weight" in n],
             ]
@@ -543,6 +552,9 @@ def __init__(
         if self.zero2_process_group is not None:
             assert not self.move_params_to_cpu

+    def _is_fp8_dtype(self) -> bool:
+        return _is_fp8_dtype(self.compute_dtype)
+
     def _get_gradient_predivide_factor(self, world_size: int) -> float:
         factor: int = 1
         while world_size % factor == 0 and world_size / factor > factor:
@@ -672,7 +684,7 @@ def _cast_buffers(
     @property
     def params_with_grad(self) -> List[Parameter]:
         """[p for p in self.parameters() if p.grad is not None]"""
-        return [p for p in self.parameters() if p.grad is not None]
+        return [p for p in self.parameters() if p.grad is not None or getattr(p, "main_grad", None) is not None]

     @torch.no_grad()
     def clip_grad_norm_(
@@ -775,7 +787,10 @@ def _shard_parameters_(self) -> None:
             assert p.dtype == torch.float32

             # If world_size is 1, then we all-reduce grads instead of sharding.
-            p._is_sharded = self.world_size > 1 and isinstance(p, FlatParameter)
+            # An exception is norm weights during fp8 training.
+            p._is_sharded = self.world_size > 1 and (
+                not self._is_fp8_dtype() or isinstance(p, FlatParameter)
+            )
             p._orig_size = p.data.size()

             if not p._is_sharded:
@@ -1279,11 +1294,9 @@ def _init_param_attributes(self, p: Parameter) -> None:
         # storage to size 0 at init (here) and re-materialize (by copying
         # from _fp32_shard) as needed. If offloading params to CPU, the
         # dtype of the fp16 shard will depend on the *`compute_dtype`*.
-        if self.compute_dtype in [
-            torch.float8_e5m2,
-            torch.float8_e4m3fn,
-        ] and not isinstance(p, FlatParameter):
+        if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
             # assume non-flattened params are precision-critical, like norms
+            assert not p._is_sharded
             dtype = torch.bfloat16
         else:
             dtype = self.compute_dtype
@@ -1307,11 +1320,9 @@ def _init_param_attributes(self, p: Parameter) -> None:
             # world_size, although these padding elements will be removed before the
             # relevant computation.
         if p._is_sharded:
-            if self.compute_dtype in [
-                torch.float8_e5m2,
-                torch.float8_e4m3fn,
-            ] and not isinstance(p, FlatParameter):
+            if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
                 # assume non-flattened params are precision-critical, like norms
+                assert not p._is_sharded
                 dtype = torch.bfloat16
             else:
                 dtype = self.compute_dtype
@@ -1439,9 +1450,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         if self._is_root and self.mixed_precision:
             args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)

+        just_added_to_fsdp_forward_ordering = False
         if self not in self._fsdp_forward_ordering:
             self._my_fsdp_instance_idx = len(self._fsdp_forward_ordering)
             self._fsdp_forward_ordering.append(self)
+            just_added_to_fsdp_forward_ordering = True

         # If enabled, convert the input to FP32 if we are in full precision.
         # no_grad is not used because the input might be for a non-root instance,
@@ -1449,6 +1462,46 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         if self.force_input_to_fp32 and not self.mixed_precision:
             args, kwargs = cast_floats_to_right_precision(False, False, is_bf16, *args, **kwargs)

+        # need to use the fp32_to_fp16 stream since _cast_fp32_param_shards_to_fp16
+        # depends on this block.
+        with torch.no_grad(), torch.cuda.stream(self._streams["fp32_to_fp16"]):
+            # Collect parameters whose scale/scale_inv must be updated before
+            # _cast_fp32_param_shards_to_fp16 uses the fp8 scales to quantize
+            # ahead of the all-gather. These include params whose all-gather
+            # we prefetch.
+            params = []
+            if self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1:
+                if self._my_fsdp_instance_idx == 0 and self._is_fp8_dtype():
+                    # The first FSDP instance didn't have a chance to prefetch
+                    params = self.params
+                if self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._is_fp8_dtype():
+                    # The FSDP instance whose all-gather we'll prefetch
+                    params.extend(self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1].params)
+            elif just_added_to_fsdp_forward_ordering:
+                # In the first iteration, we didn't have a chance to record
+                # fsdp_instance_idx to prefetch
+                if self._is_fp8_dtype():
+                    params = self.params
+
+            for p in params:
+                if not isinstance(p, FlatParameter):
+                    continue
+                d = {info[0]: info[1] for info in p._param_infos}
+                for n, m in d.items():
+                    # Previous iteration was grad_enabled
+                    if m.fp8_meta.get("update_amax_and_scale_fwd", False):
+                        if m.fp8_meta["recipe"].reduce_amax:
+                            FP8GlobalStateManager.copy_amax_from_global_buffer(
+                                m.fp8_meta, forward=True
+                            )
+                            # FIXME update_weight_scale_inv is only True for the first micro-batch
+                            amax_and_scale_update(m.fp8_meta, True)
+                            FP8GlobalStateManager.set_amax_buffer_key_deletion(
+                                m.fp8_meta, forward=True
+                            )
+                        else:
+                            amax_and_scale_update(m.fp8_meta, True)
+
         # All-gather full parameters. This will also transfer FP32 parameters to
         # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
         self._rebuild_full_params()
@@ -1690,7 +1743,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # then subsequent hook callbacks will see POST state.
         self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
         self.training_state = TrainingState.BACKWARD_POST
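+        # Prefer the fp32 main_grad accumulated in place by TE modules (wired
+        # up in _prep_grads_for_backward) over the low-precision param.grad.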
-        if param.grad is None:
+        grad_or_main_grad = (
+            param.main_grad if getattr(param, "main_grad", None) is not None else param.grad
+        )
+        if grad_or_main_grad is None:
             return

         if hasattr(param, "_linked_param"):
@@ -1703,8 +1759,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
                 param = param._linked_param

-        assert param.grad is not None, param.shape
-        if param.grad.requires_grad:
+        assert grad_or_main_grad is not None, param.shape
+        if grad_or_main_grad.requires_grad:
             raise RuntimeError("FSDP only works with gradients that don't require gradients")

         if self._require_backward_grad_sync or self.reshard_after_forward:
@@ -1731,27 +1787,32 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
+            orig_grad_data = grad_or_main_grad

             if self.mixed_precision:
                 if self.fp32_reduce_scatter:
                     # Cast grad to FP32.
-                    param.grad.data = param.grad.data.to(param.dtype)
-                elif self.compute_dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
+                    grad_or_main_grad.data = grad_or_main_grad.to(param.dtype)
+                elif self._is_fp8_dtype():
                     # Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
-                    param.grad.data = param.grad.data.to(torch.bfloat16)
+                    grad_or_main_grad.data = grad_or_main_grad.to(torch.bfloat16)

             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                grad_or_main_grad.div_(self.gradient_predivide_factor)

+            # logging.info(f"{torch.distributed.get_rank()=} {param._is_sharded=}")
             if param._is_sharded:
                 assert self._reducer is not None
                 # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if hasattr(param, "main_grad"):
+                    grad = param.main_grad
+                    param.main_grad = None
+                else:
+                    grad = param.grad
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
@@ -1774,7 +1835,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # Currently the only way for _is_sharded to be False is if
                 # world_size == 1. This could be relaxed in the future, in which
                 # case grads should be all-reduced here.
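+                # (With fp8 training, replicated norm weights also reach this
+                # path with world_size > 1, hence the assert stays disabled.)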
                 # assert self.world_size == 1
-                self._post_reduction_hook(param, param.grad.data)
+                self._post_reduction_hook(param, grad_or_main_grad)

             # After _post_backward_hook returns, orig_grad_data will eventually
             # go out of scope, at which point it could otherwise be freed for
@@ -1886,6 +1947,9 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     p.device == p._saved_grad_shard.device,
                     f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
                 )
+                assert p.shape == p._saved_grad_shard.shape
+                assert p.dtype == p._saved_grad_shard.dtype
+                assert getattr(p, "main_grad", None) is None
                 p.grad = p._saved_grad_shard

             if hasattr(p, "_saved_grad_shard"):
@@ -2212,6 +2276,15 @@ def _prep_grads_for_backward(self) -> None:
         right shape, device, accumulated values, etc.
         """
         for p in self.params:
+            if isinstance(p, FlatParameter) and all(
+                _is_te_module_with_weights(info[1]) for info in p._param_infos
+            ):
+                if getattr(p, "main_grad", None) is None:
+                    p.main_grad = torch.empty_like(p, dtype=torch.float)
+                main_grad_views = p.get_param_views(p.main_grad)
+                for (_, m, n), main_grad in zip(p._param_infos, main_grad_views):
+                    getattr(m, n).main_grad = main_grad
+
             if p.grad is not None:
                 if p.grad.device != p.data.device:
                     p.grad = None
@@ -2311,8 +2384,8 @@ def local_metadata_dict(self) -> Dict[str, Any]:
                     backing_param_name = m.module.flat_param_names[i]
                     names, shapes, numels = m.module.metadata(i)
                 else:
-                    assert len(m._param_name_groups[i]) == 1
-                    backing_param_name = m._param_name_groups[i][0]
+                    # assert len(m._param_name_groups[i]) == 1
+                    backing_param_name = m._param_name_groups[m._num_flatten_params][i - m._num_flatten_params]
                     names = [backing_param_name]
                     shapes = [p._orig_size]
                     numels = [p._orig_size.numel()]
@@ -2428,10 +2501,8 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
             for p in params:
                 assert p._fp16_shard is not None
                 alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
-                if self.compute_dtype in [
-                    torch.float8_e5m2,
-                    torch.float8_e4m3fn,
-                ] and p._fp16_shard.dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:
+                if self._is_fp8_dtype() and _is_fp8_dtype(p._fp16_shard.dtype):
+                    # fp8 quantization
                     assert isinstance(p, FlatParameter)
                     assert len(p._param_infos) == len(p._param_numels)
                     numel_per_shard = p.numel()
@@ -2442,9 +2513,7 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
                         if offset + numel <= 0 or offset >= numel_per_shard:
                             offset += numel
                             continue
-                        assert isinstance(
-                            m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP)
-                        )
+                        assert _is_te_module_with_weights(m)
                         fp8_dtype_forward = te.fp8.get_fp8_te_dtype(
                             m.fp8_meta["recipe"], fprop_tensor=True
                         )