Fp8 all gather hack #1136
base: ngoyal_added_zero2_shard_modelparams_multiple_gpus
@@ -39,8 +39,10 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import transformer_engine.pytorch as te

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.misc.flatten_params_wrapper import FlatParameter
from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import (
@@ -53,6 +55,8 @@
from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8, DType, FP8FwdTensors
from transformer_engine.pytorch.fp8 import amax_and_scale_update, FP8GlobalStateManager

from . import fsdp_optim_utils as ou
@@ -116,6 +120,14 @@ class OffloadConfig:
    dir: Optional[str] = None


def _is_te_module_with_weights(m: nn.Module) -> bool:
    return isinstance(m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP))


def _is_fp8_dtype(dtype: torch.dtype) -> bool:
    return dtype in [torch.float8_e5m2, torch.float8_e4m3fn]


class FullyShardedDataParallel(nn.Module):
    """
    A wrapper for sharding Module parameters across data parallel workers. This
@@ -455,9 +467,24 @@ def __init__(
            non_flatten_params = params
            param_name_groups = [[n] for n in param_names]
        if self.flatten_parameters:
            to_be_flatten_params = [params]
            non_flatten_params = []
            param_name_groups = [param_names]
            # don't flatten norm_weights since we need to handle them
            # separately during fp8 training
            to_be_flatten_params = [
                [
                    params[i]
                    for i in range(len(params))
                    if "norm_weight" not in param_names[i]
                ]
            ]
            non_flatten_params = [
                params[i]
                for i in range(len(params))
                if "norm_weight" in param_names[i]
            ]
            param_name_groups = [
                [n for n in param_names if "norm_weight" not in n],
                [n for n in param_names if "norm_weight" in n],
            ]
        del param_names

        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
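Note (not part of the diff): a standalone sketch, with hypothetical parameter names, of how the grouping above behaves. Anything whose name contains "norm_weight" is kept out of the FlatParameter so it can stay unflattened and in higher precision during fp8 training.

```python
# Hypothetical names; only the grouping rule from the diff is reproduced here.
param_names = ["fc1_weight", "fc2_weight", "attn_norm_weight", "ffn_norm_weight"]
params = list(range(len(param_names)))  # stand-ins for nn.Parameter objects

to_be_flatten_params = [
    [params[i] for i in range(len(params)) if "norm_weight" not in param_names[i]]
]
non_flatten_params = [
    params[i] for i in range(len(params)) if "norm_weight" in param_names[i]
]
param_name_groups = [
    [n for n in param_names if "norm_weight" not in n],
    [n for n in param_names if "norm_weight" in n],
]

assert param_name_groups == [
    ["fc1_weight", "fc2_weight"],
    ["attn_norm_weight", "ffn_norm_weight"],
]
```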
@@ -525,6 +552,9 @@ def __init__(
        if self.zero2_process_group is not None:
            assert not self.move_params_to_cpu

    def _is_fp8_dtype(self) -> bool:
        return _is_fp8_dtype(self.compute_dtype)

    def _get_gradient_predivide_factor(self, world_size: int) -> float:
        factor: int = 1
        while world_size % factor == 0 and world_size / factor > factor:
@@ -654,7 +684,7 @@ def _cast_buffers(
    @property
    def params_with_grad(self) -> List[Parameter]:
        """[p for p in self.parameters() if p.grad is not None]"""
        return [p for p in self.parameters() if p.grad is not None or getattr(p, "main_grad", None) is not None]

    @torch.no_grad()
    def clip_grad_norm_(
@@ -757,7 +787,10 @@ def _shard_parameters_(self) -> None:
            assert p.dtype == torch.float32

            # If world_size is 1, then we all-reduce grads instead of sharding.
            p._is_sharded = self.world_size > 1
            # An exception is norm weights during fp8 training.
            p._is_sharded = self.world_size > 1 and (
                not self._is_fp8_dtype() or isinstance(p, FlatParameter)
            )
            p._orig_size = p.data.size()

            if not p._is_sharded:
@@ -1141,7 +1174,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
                    non_shared_params
                ), f"{len(full_tensors)} vs. {len(non_shared_params)}"
                for p, (full_tensor, safe_to_free) in zip(non_shared_params, full_tensors):
                    if not volatile:
                    if not volatile and p._is_sharded:
                        # Copy any changes made to the full params back into
                        # the corresponding local shards.
                        local_shard, _ = self._get_shard(full_tensor)
@@ -1261,7 +1294,15 @@ def _init_param_attributes(self, p: Parameter) -> None:
            # storage to size 0 at init (here) and re-materialize (by copying
            # from _fp32_shard) as needed. If offloading params to CPU, the
            # dtype of the fp16 shard will depend on the *`compute_dtype`*.
            p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
            if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
                # assume non flattened are precision critical like norm
                assert not p._is_sharded
                dtype = torch.bfloat16
            else:
                dtype = self.compute_dtype
            p._fp16_shard = torch.zeros_like(
                p._fp32_shard, device=self.compute_device, dtype=dtype
            )
            free_storage_(p._fp16_shard)

        if self.mixed_precision:
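Note (illustration only, not code from the PR): the dtype rule introduced here, and repeated for `_full_param_padded` in the next hunk, can be summarized as a small helper. This assumes a PyTorch build that exposes the float8 dtypes.

```python
import torch

def fp16_shard_dtype(compute_dtype: torch.dtype, is_flat_param: bool) -> torch.dtype:
    """Mirror of the branch above: unflattened params (assumed precision-critical,
    e.g. norm weights) keep bf16 shards; flattened weights use the fp8 compute dtype."""
    if compute_dtype in (torch.float8_e5m2, torch.float8_e4m3fn) and not is_flat_param:
        return torch.bfloat16
    return compute_dtype

assert fp16_shard_dtype(torch.float8_e4m3fn, is_flat_param=True) is torch.float8_e4m3fn
assert fp16_shard_dtype(torch.float8_e4m3fn, is_flat_param=False) is torch.bfloat16
assert fp16_shard_dtype(torch.bfloat16, is_flat_param=True) is torch.bfloat16
```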
@@ -1279,8 +1320,16 @@ def _init_param_attributes(self, p: Parameter) -> None:
        # world_size, although these padding elements will be removed before the
        # relevant computation.
        if p._is_sharded:
            if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
                # assume non flattened are precision critical like norm
                assert not p._is_sharded
                dtype = torch.bfloat16
            else:
                dtype = self.compute_dtype
            p._full_param_padded = torch.zeros(
                p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
                p.data.numel() * self.world_size,
                device=self.compute_device,
                dtype=dtype,
            )
            free_storage_(p._full_param_padded)
@@ -1393,20 +1442,66 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:

        # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
        # the conversion).
        is_bf16 = self.compute_dtype == torch.bfloat16
        is_bf16 = self.compute_dtype in [
            torch.bfloat16,
            torch.float8_e5m2,
            torch.float8_e4m3fn,
        ]
        if self._is_root and self.mixed_precision:
            args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)

        just_added_to_fsdp_forward_ordering = False
        if self not in self._fsdp_forward_ordering:
            self._my_fsdp_instance_idx = len(self._fsdp_forward_ordering)
            self._fsdp_forward_ordering.append(self)
            just_added_to_fsdp_forward_ordering = True

        # If enabled, convert the input to FP32 if we are in full precision.
        # no_grad is not used because the input might be for a non-root instance,
        # which mean autograd needs to go through the conversion.
        if self.force_input_to_fp32 and not self.mixed_precision:
            args, kwargs = cast_floats_to_right_precision(False, False, is_bf16, *args, **kwargs)

        # need to use fp32_to_fp16 stream since _cast_fp32_param_shards_to_fp16
        # depends on this block.
        with torch.no_grad(), torch.cuda.stream(self._streams["fp32_to_fp16"]):
            # Collect parameters to update scale/scale_inv before we
            # _cast_fp32_param_shards_to_fp16 that uses fp8 scale to quantize
            # before all-gather.
            # These include params we prefetch all-gather.
            params = []
            if self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1:
                if self._my_fsdp_instance_idx == 0 and self._is_fp8_dtype():
                    # The first FSDP instance didn't have chance to prefetch
                    params = self.params
                if self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._is_fp8_dtype():
                    # FSDP instance we'll prefetch all-gather
                    params.extend(self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1].params)
            elif just_added_to_fsdp_forward_ordering:
                # In the first iteration, we didn't have chance to record
                # fsdp_instance_idx to prefetch
                if self._is_fp8_dtype():
                    params = self.params

            for p in params:
                if not isinstance(p, FlatParameter):
                    continue
                d = {info[0]: info[1] for info in p._param_infos}
                for n, m in d.items():
                    # Previous iteration was grad_enabled
                    if m.fp8_meta.get("update_amax_and_scale_fwd", False):
                        if m.fp8_meta["recipe"].reduce_amax:
                            FP8GlobalStateManager.copy_amax_from_global_buffer(
                                m.fp8_meta, forward=True
                            )
                            # FIXME update_weight_scale_inv is only True for the first micro-batch
                            amax_and_scale_update(m.fp8_meta, True)
                            FP8GlobalStateManager.set_amax_buffer_key_deletion(
                                m.fp8_meta, forward=True
                            )
                        else:
                            amax_and_scale_update(m.fp8_meta, True)

            # All-gather full parameters. This will also transfer FP32 parameters to
            # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
            self._rebuild_full_params()
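Note (illustration only, names are hypothetical): the selection logic above decides which FSDP instances get their fp8 amax/scale refreshed before shards are quantized for all-gather: the next instance whose all-gather we prefetch, plus the first instance (nobody prefetches for it), plus this instance on the very first iteration when the forward ordering is not known yet. A simplified standalone sketch of that rule:

```python
# Illustrative only: which FSDP instances have fp8 scales refreshed in this forward.
def fp8_instances_to_refresh(forward_ordering, my_idx, just_added, is_fp8):
    selected = []
    if my_idx < len(forward_ordering) - 1:
        if my_idx == 0 and is_fp8(forward_ordering[0]):
            # The first instance has nobody prefetching its all-gather for it.
            selected.append(forward_ordering[0])
        if is_fp8(forward_ordering[my_idx + 1]):
            # The instance whose all-gather we are about to prefetch.
            selected.append(forward_ordering[my_idx + 1])
    elif just_added and is_fp8(forward_ordering[my_idx]):
        # First iteration: the forward ordering is not known yet.
        selected.append(forward_ordering[my_idx])
    return selected

order = ["fsdp0", "fsdp1", "fsdp2"]
assert fp8_instances_to_refresh(order, 0, False, lambda m: True) == ["fsdp0", "fsdp1"]
assert fp8_instances_to_refresh(order, 2, True, lambda m: True) == ["fsdp2"]
```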
@@ -1648,7 +1743,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
        # then subsequent hook callbacks will see POST state.
        self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
        self.training_state = TrainingState.BACKWARD_POST
        if param.grad is None:
        grad_or_main_grad = (
            param.main_grad if getattr(param, "main_grad", None) is not None else param.grad
        )
        if grad_or_main_grad is None:
            return

        if hasattr(param, "_linked_param"):
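Note (sketch, not code from the PR): the `main_grad` fallback used throughout the backward hook can be read as a small helper, assuming TE modules accumulate their weight gradients into a `main_grad` attribute instead of `.grad`.

```python
from typing import Optional

import torch

def grad_or_main_grad(param: torch.nn.Parameter) -> Optional[torch.Tensor]:
    # Prefer the fp32 wgrad buffer that Transformer Engine modules accumulate into,
    # and fall back to the regular autograd .grad otherwise.
    main_grad = getattr(param, "main_grad", None)
    return main_grad if main_grad is not None else param.grad

p = torch.nn.Parameter(torch.zeros(4))
assert grad_or_main_grad(p) is None          # no grad yet
p.main_grad = torch.ones(4)                  # simulate a module writing a fused wgrad
assert grad_or_main_grad(p) is p.main_grad
```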
@@ -1661,8 +1759,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
            if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
                param = param._linked_param

        assert param.grad is not None, param.shape
        if param.grad.requires_grad:
        assert grad_or_main_grad is not None, param.shape
        if grad_or_main_grad.requires_grad:
            raise RuntimeError("FSDP only works with gradients that don't require gradients")

        if self._require_backward_grad_sync or self.reshard_after_forward:
@@ -1689,23 +1787,32 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
            # reductions in post_backward stream.
            self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self._streams["post_backward"]):
                orig_grad_data = param.grad.data
                orig_grad_data = grad_or_main_grad

                if self.mixed_precision and self.fp32_reduce_scatter:
                    # Cast grad to FP32.
                    param.grad.data = param.grad.data.to(param.dtype)
                if self.mixed_precision:
                    if self.fp32_reduce_scatter:
                        # Cast grad to FP32.
                        grad_or_main_grad.data = grad_or_main_grad.to(param.dtype)
                    elif self._is_fp8_dtype():
                        # Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
Review comment: Currently this is not working with the latest FP8 wgrad?
Reply: This is meant to be for future work when we have fp8 reduce-scatter. I'll update the comment.
                        grad_or_main_grad.data = grad_or_main_grad.to(torch.bfloat16)

                if self.gradient_predivide_factor > 1:
                    # Average grad by world_size for consistency with PyTorch DDP.
                    param.grad.data.div_(self.gradient_predivide_factor)
                    grad_or_main_grad.div_(self.gradient_predivide_factor)

                # logging.info(f"{torch.distributed.get_rank()=} {param._is_sharded=}")
                if param._is_sharded:
                    assert self._reducer is not None
                    # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                    # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                    # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                    # matter, neglecting rounding.
                    grad = param.grad.data
                    if hasattr(param, "main_grad"):
                        grad = param.main_grad
                        param.main_grad = None
                    else:
                        grad = param.grad
                    # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                    #
                    # The effect on memory consumption is not usually significant. No extra memory is allocated if this
@@ -1727,8 +1834,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                # Currently the only way for _is_sharded to be False is if
                # world_size == 1. This could be relaxed in the future, in which
                # case grads should be all-reduced here.
                assert self.world_size == 1
                self._post_reduction_hook(param, param.grad.data)
                # assert self.world_size == 1
                self._post_reduction_hook(param, grad_or_main_grad)

            # After _post_backward_hook returns, orig_grad_data will eventually
            # go out of scope, at which point it could otherwise be freed for
@@ -1840,6 +1947,9 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                    p.device == p._saved_grad_shard.device,
                    f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
                )
                assert p.shape == p._saved_grad_shard.shape
                assert p.dtype == p._saved_grad_shard.dtype
                assert getattr(p, "main_grad", None) is None
                p.grad = p._saved_grad_shard

            if hasattr(p, "_saved_grad_shard"):
@@ -2166,6 +2276,15 @@ def _prep_grads_for_backward(self) -> None:
        right shape, device, accumulated values, etc.
        """
        for p in self.params:
            if isinstance(p, FlatParameter) and all(
                _is_te_module_with_weights(info[1]) for info in p._param_infos
            ):
                if getattr(p, "main_grad", None) is None:
                    p.main_grad = torch.empty_like(p, dtype=torch.float)
                main_grad_views = p.get_param_views(p.main_grad)
                for (_, m, n), main_grad in zip(p._param_infos, main_grad_views):
                    getattr(m, n).main_grad = main_grad

            if p.grad is not None:
                if p.grad.device != p.data.device:
                    p.grad = None
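Note (standalone sketch, hypothetical shapes): what the `main_grad` wiring above achieves is one flat fp32 buffer per FlatParameter, with per-weight views handed to the owning modules so their wgrad kernels accumulate straight into the flat buffer that will later be reduce-scattered.

```python
import torch

# Hypothetical packing of two weights into one flat parameter.
param_shapes = [(4, 2), (3,)]
flat_numel = sum(torch.Size(s).numel() for s in param_shapes)
flat_param = torch.nn.Parameter(torch.randn(flat_numel))

# One flat fp32 grad buffer, plus per-weight views into it.
flat_param.main_grad = torch.zeros(flat_numel, dtype=torch.float)
views, offset = [], 0
for shape in param_shapes:
    n = torch.Size(shape).numel()
    views.append(flat_param.main_grad[offset : offset + n].view(shape))
    offset += n

views[0].add_(1.0)  # a module accumulating its weight grad into its view
assert flat_param.main_grad[: 4 * 2].eq(1.0).all()
assert flat_param.main_grad[4 * 2 :].eq(0.0).all()
```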
@@ -2265,8 +2384,8 @@ def local_metadata_dict(self) -> Dict[str, Any]:
                backing_param_name = m.module.flat_param_names[i]
                names, shapes, numels = m.module.metadata(i)
            else:
                assert len(m._param_name_groups[i]) == 1
                backing_param_name = m._param_name_groups[i][0]
                # assert len(m._param_name_groups[i]) == 1
                backing_param_name = m._param_name_groups[m._num_flatten_params][i - m._num_flatten_params]
Review comment: Need to make sure checkpointing works properly with this.
                names = [backing_param_name]
                shapes = [p._orig_size]
                numels = [p._orig_size.numel()]
@@ -2382,12 +2501,50 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
            for p in params:
                assert p._fp16_shard is not None
                alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
                p._fp16_shard.copy_(
                    # If move_params_to_cpu is True, this will be non-blocking
                    # because _fp32_shard is pinned, otherwise it's a no-op.
                    p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
                )
                p.data = p._fp16_shard
                if self._is_fp8_dtype() and _is_fp8_dtype(p._fp16_shard.dtype):
                    # fp8 quantization
                    assert isinstance(p, FlatParameter)
                    assert len(p._param_infos) == len(p._param_numels)
                    numel_per_shard = p.numel()
                    offset = -numel_per_shard * self.rank
                    for i in range(len(p._param_infos)):
                        _, m, n = p._param_infos[i]
                        numel = p._param_numels[i]
                        if offset + numel <= 0 or offset >= numel_per_shard:
                            offset += numel
                            continue
                        assert _is_te_module_with_weights(m)
                        fp8_dtype_forward = te.fp8.get_fp8_te_dtype(
                            m.fp8_meta["recipe"], fprop_tensor=True
                        )
                        if not m.fp8_initialized:
                            m.fp8_init(
                                num_gemms=2 if isinstance(m, te.LayerNormMLP) else 1
                            )
                        begin = max(offset, 0)
                        end = min(offset + numel, numel_per_shard)
                        cast_to_fp8(
                            p._fp32_shard[begin:end],
                            m.fp8_meta["scaling_fwd"],
                            FP8FwdTensors.GEMM2_WEIGHT
                            if n == "fc2_weight"
                            else FP8FwdTensors.GEMM1_WEIGHT,
                            fp8_dtype_forward,
                            out=p._fp16_shard[begin:end],
                        )
                        offset += numel
                    p.data = p._fp16_shard.view(
                        torch.float8_e4m3fn
                        if fp8_dtype_forward == DType.kFloat8E4M3
                        else torch.float8_e5m2
                    )
                else:
                    p._fp16_shard.copy_(
                        # If move_params_to_cpu is True, this will be non-blocking
                        # because _fp32_shard is pinned, otherwise it's a no-op.
                        p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
                    )
                    p.data = p._fp16_shard
        torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

    @torch.no_grad()
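Note (illustration only, hypothetical sizes, no TE calls): the offset arithmetic in the fp8 branch above walks the original parameters packed into the FlatParameter and casts only the portion that overlaps this rank's shard. A standalone sketch of just that overlap computation:

```python
# Illustration of the shard-overlap walk used before cast_to_fp8 above.
def shard_overlaps(param_numels, numel_per_shard, rank):
    """Return (param_index, begin, end) slices, in shard-local coordinates,
    for every packed param that overlaps this rank's shard."""
    overlaps = []
    offset = -numel_per_shard * rank
    for i, numel in enumerate(param_numels):
        if offset + numel <= 0 or offset >= numel_per_shard:
            offset += numel
            continue
        begin, end = max(offset, 0), min(offset + numel, numel_per_shard)
        overlaps.append((i, begin, end))
        offset += numel
    return overlaps

# Two packed weights (6 and 10 elements) sharded over 2 ranks of 8 elements each.
assert shard_overlaps([6, 10], numel_per_shard=8, rank=0) == [(0, 0, 6), (1, 6, 8)]
assert shard_overlaps([6, 10], numel_per_shard=8, rank=1) == [(1, 0, 8)]
```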
Review comment: nit: is_bf16_or_fp8