Fp8 all gather hack #1136

Open · wants to merge 3 commits into base: ngoyal_added_zero2_shard_modelparams_multiple_gpus

Changes from all commits
213 changes: 185 additions & 28 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
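Before the diff itself, a minimal sketch of how the FP8 all-gather path added here might be exercised. This is not taken from the PR: it assumes TransformerEngine is installed, that this branch's FullyShardedDataParallel accepts an FP8 compute_dtype as the changes below imply, and that torch.distributed is already initialized (e.g. via torchrun); the layer sizes and recipe are illustrative only.

import torch
import torch.nn as nn
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# A TransformerEngine block: its GEMM weights land in the FlatParameter and are
# all-gathered directly in fp8, while layer-norm weights stay unflattened and
# are cast to bf16 instead (see the changes to __init__ and
# _cast_fp32_param_shards_to_fp16 below).
block = nn.Sequential(
    te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096),
    te.Linear(1024, 1024),
).cuda()

model = FSDP(
    block,
    flatten_parameters=True,            # TE weights go into the FlatParameter
    mixed_precision=True,
    compute_dtype=torch.float8_e4m3fn,  # triggers the fp8 all-gather path
)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    out = model(x)
out.sum().backward()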
@@ -39,8 +39,10 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import transformer_engine.pytorch as te

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.misc.flatten_params_wrapper import FlatParameter
from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import (
@@ -53,6 +55,8 @@
from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8, DType, FP8FwdTensors
from transformer_engine.pytorch.fp8 import amax_and_scale_update, FP8GlobalStateManager

from . import fsdp_optim_utils as ou

@@ -116,6 +120,14 @@ class OffloadConfig:
dir: Optional[str] = None


def _is_te_module_with_weights(m: nn.Module) -> bool:
return isinstance(m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP))


def _is_fp8_dtype(dtype: torch.dtype) -> bool:
return dtype in [torch.float8_e5m2, torch.float8_e4m3fn]


class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
@@ -455,9 +467,24 @@ def __init__(
non_flatten_params = params
param_name_groups = [[n] for n in param_names]
if self.flatten_parameters:
to_be_flatten_params = [params]
non_flatten_params = []
param_name_groups = [param_names]
# don't flatten norm_weights since we need to handle them
# separately during fp8 training
to_be_flatten_params = [
[
params[i]
for i in range(len(params))
if "norm_weight" not in param_names[i]
]
]
non_flatten_params = [
params[i]
for i in range(len(params))
if "norm_weight" in param_names[i]
]
param_name_groups = [
[n for n in param_names if "norm_weight" not in n],
[n for n in param_names if "norm_weight" in n],
]
del param_names

self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
@@ -525,6 +552,9 @@ def __init__(
if self.zero2_process_group is not None:
assert not self.move_params_to_cpu

def _is_fp8_dtype(self) -> bool:
return _is_fp8_dtype(self.compute_dtype)

def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
@@ -654,7 +684,7 @@ def _cast_buffers(
@property
def params_with_grad(self) -> List[Parameter]:
"""[p for p in self.parameters() if p.grad is not None]"""
return [p for p in self.parameters() if p.grad is not None]
return [p for p in self.parameters() if p.grad is not None or getattr(p, "main_grad", None) is not None]

@torch.no_grad()
def clip_grad_norm_(
@@ -757,7 +787,10 @@ def _shard_parameters_(self) -> None:
assert p.dtype == torch.float32

# If world_size is 1, then we all-reduce grads instead of sharding.
p._is_sharded = self.world_size > 1
# An exception is norm weights during fp8 training.
p._is_sharded = self.world_size > 1 and (
not self._is_fp8_dtype() or isinstance(p, FlatParameter)
)
p._orig_size = p.data.size()

if not p._is_sharded:
@@ -1141,7 +1174,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
non_shared_params
), f"{len(full_tensors)} vs. {len(non_shared_params)}"
for p, (full_tensor, safe_to_free) in zip(non_shared_params, full_tensors):
if not volatile:
if not volatile and p._is_sharded:
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard, _ = self._get_shard(full_tensor)
@@ -1261,7 +1294,15 @@ def _init_param_attributes(self, p: Parameter) -> None:
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed. If offloading params to CPU, the
# dtype of the fp16 shard will depend on the *`compute_dtype`*.
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
# assume non-flattened params are precision-critical, like norms
assert not p._is_sharded
dtype = torch.bfloat16
else:
dtype = self.compute_dtype
p._fp16_shard = torch.zeros_like(
p._fp32_shard, device=self.compute_device, dtype=dtype
)
free_storage_(p._fp16_shard)

if self.mixed_precision:
@@ -1279,8 +1320,16 @@ def _init_param_attributes(self, p: Parameter) -> None:
# world_size, although these padding elements will be removed before the
# relevant computation.
if p._is_sharded:
if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
# assume non-flattened params are precision-critical, like norms
assert not p._is_sharded
dtype = torch.bfloat16
else:
dtype = self.compute_dtype
p._full_param_padded = torch.zeros(
p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
p.data.numel() * self.world_size,
device=self.compute_device,
dtype=dtype,
)
free_storage_(p._full_param_padded)

@@ -1393,20 +1442,66 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:

# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
# the conversion).
is_bf16 = self.compute_dtype == torch.bfloat16
is_bf16 = self.compute_dtype in [
torch.bfloat16,
torch.float8_e5m2,
torch.float8_e4m3fn,
]

Review comment (Member): nit: is_bf16_or_fp8

if self._is_root and self.mixed_precision:
args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)

just_added_to_fsdp_forward_ordering = False
if self not in self._fsdp_forward_ordering:
self._my_fsdp_instance_idx = len(self._fsdp_forward_ordering)
self._fsdp_forward_ordering.append(self)
just_added_to_fsdp_forward_ordering = True

# If enabled, convert the input to FP32 if we are in full precision.
# no_grad is not used because the input might be for a non-root instance,
# which means autograd needs to go through the conversion.
if self.force_input_to_fp32 and not self.mixed_precision:
args, kwargs = cast_floats_to_right_precision(False, False, is_bf16, *args, **kwargs)

# need to use fp32_to_fp16 stream since _cast_fp32_param_shards_to_fp16
# depends on this block.
with torch.no_grad(), torch.cuda.stream(self._streams["fp32_to_fp16"]):
# Collect parameters whose scale/scale_inv must be updated before
# _cast_fp32_param_shards_to_fp16 uses the fp8 scale to quantize
# ahead of the all-gather. This includes params whose all-gather
# we prefetch.
params = []
if self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1:
if self._my_fsdp_instance_idx == 0 and self._is_fp8_dtype():
# The first FSDP instance didn't have a chance to prefetch.
params = list(self.params)  # copy so the extend below doesn't mutate self.params
if self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._is_fp8_dtype():
# The FSDP instance whose all-gather we'll prefetch next.
params.extend(self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1].params)
elif just_added_to_fsdp_forward_ordering:
# In the first iteration, we didn't have a chance to record the
# fsdp_instance_idx to prefetch.
if self._is_fp8_dtype():
params = self.params

for p in params:
if not isinstance(p, FlatParameter):
continue
d = {info[0]: info[1] for info in p._param_infos}
for n, m in d.items():
# Previous iteration was grad_enabled
if m.fp8_meta.get("update_amax_and_scale_fwd", False):
if m.fp8_meta["recipe"].reduce_amax:
FP8GlobalStateManager.copy_amax_from_global_buffer(
m.fp8_meta, forward=True
)
# FIXME update_weight_scale_inv is only True for the first micro-batch
amax_and_scale_update(m.fp8_meta, True)
FP8GlobalStateManager.set_amax_buffer_key_deletion(
m.fp8_meta, forward=True
)
else:
amax_and_scale_update(m.fp8_meta, True)

# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self._rebuild_full_params()
@@ -1648,7 +1743,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# then subsequent hook callbacks will see POST state.
self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
self.training_state = TrainingState.BACKWARD_POST
if param.grad is None:
grad_or_main_grad = (
param.main_grad if getattr(param, "main_grad", None) is not None else param.grad
)
if grad_or_main_grad is None:
return

if hasattr(param, "_linked_param"):
@@ -1661,8 +1759,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
param = param._linked_param

assert param.grad is not None, param.shape
if param.grad.requires_grad:
assert grad_or_main_grad is not None, param.shape
if grad_or_main_grad.requires_grad:
raise RuntimeError("FSDP only works with gradients that don't require gradients")

if self._require_backward_grad_sync or self.reshard_after_forward:
@@ -1689,23 +1787,32 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# reductions in post_backward stream.
self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._streams["post_backward"]):
orig_grad_data = param.grad.data
orig_grad_data = grad_or_main_grad

if self.mixed_precision and self.fp32_reduce_scatter:
# Cast grad to FP32.
param.grad.data = param.grad.data.to(param.dtype)
if self.mixed_precision:
if self.fp32_reduce_scatter:
# Cast grad to FP32.
grad_or_main_grad.data = grad_or_main_grad.to(param.dtype)
elif self._is_fp8_dtype():
# Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
grad_or_main_grad.data = grad_or_main_grad.to(torch.bfloat16)

Review comment (Member): Currently this is not working with the latest FP8 wgrad?
Reply (Author): This is meant to be for future work when we have fp8 reduce-scatter. I'll update the comment.

if self.gradient_predivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
param.grad.data.div_(self.gradient_predivide_factor)
grad_or_main_grad.div_(self.gradient_predivide_factor)

# logging.info(f"{torch.distributed.get_rank()=} {param._is_sharded=}")
if param._is_sharded:
assert self._reducer is not None
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# matter, neglecting rounding.
grad = param.grad.data
if hasattr(param, "main_grad"):
grad = param.main_grad
param.main_grad = None
else:
grad = param.grad
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
#
# The effect on memory consumption is not usually significant. No extra memory is allocated if this
Expand All @@ -1727,8 +1834,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
# Currently the only way for _is_sharded to be False is if
# world_size == 1. This could be relaxed in the future, in which
# case grads should be all-reduced here.
assert self.world_size == 1
self._post_reduction_hook(param, param.grad.data)
# assert self.world_size == 1
self._post_reduction_hook(param, grad_or_main_grad)

# After _post_backward_hook returns, orig_grad_data will eventually
# go out of scope, at which point it could otherwise be freed for
@@ -1840,6 +1947,9 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
p.device == p._saved_grad_shard.device,
f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
)
assert p.shape == p._saved_grad_shard.shape
assert p.dtype == p._saved_grad_shard.dtype
assert getattr(p, "main_grad", None) is None
p.grad = p._saved_grad_shard

if hasattr(p, "_saved_grad_shard"):
@@ -2166,6 +2276,15 @@ def _prep_grads_for_backward(self) -> None:
right shape, device, accumulated values, etc.
"""
for p in self.params:
if isinstance(p, FlatParameter) and all(
_is_te_module_with_weights(info[1]) for info in p._param_infos
):
if getattr(p, "main_grad", None) is None:
p.main_grad = torch.empty_like(p, dtype=torch.float)
main_grad_views = p.get_param_views(p.main_grad)
for (_, m, n), main_grad in zip(p._param_infos, main_grad_views):
getattr(m, n).main_grad = main_grad

if p.grad is not None:
if p.grad.device != p.data.device:
p.grad = None
@@ -2265,8 +2384,8 @@ def local_metadata_dict(self) -> Dict[str, Any]:
backing_param_name = m.module.flat_param_names[i]
names, shapes, numels = m.module.metadata(i)
else:
assert len(m._param_name_groups[i]) == 1
backing_param_name = m._param_name_groups[i][0]
# assert len(m._param_name_groups[i]) == 1
backing_param_name = m._param_name_groups[m._num_flatten_params][i - m._num_flatten_params]

Review comment (Author): Need to make sure checkpointing works properly with this.

names = [backing_param_name]
shapes = [p._orig_size]
numels = [p._orig_size.numel()]
@@ -2382,12 +2501,50 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
p._fp16_shard.copy_(
# If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard
if self._is_fp8_dtype() and _is_fp8_dtype(p._fp16_shard.dtype):
# fp8 quantization
assert isinstance(p, FlatParameter)
assert len(p._param_infos) == len(p._param_numels)
numel_per_shard = p.numel()
offset = -numel_per_shard * self.rank
for i in range(len(p._param_infos)):
_, m, n = p._param_infos[i]
numel = p._param_numels[i]
if offset + numel <= 0 or offset >= numel_per_shard:
offset += numel
continue
assert _is_te_module_with_weights(m)
fp8_dtype_forward = te.fp8.get_fp8_te_dtype(
m.fp8_meta["recipe"], fprop_tensor=True
)
if not m.fp8_initialized:
m.fp8_init(
num_gemms=2 if isinstance(m, te.LayerNormMLP) else 1
)
begin = max(offset, 0)
end = min(offset + numel, numel_per_shard)
cast_to_fp8(
p._fp32_shard[begin:end],
m.fp8_meta["scaling_fwd"],
FP8FwdTensors.GEMM2_WEIGHT
if n == "fc2_weight"
else FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
out=p._fp16_shard[begin:end],
)
offset += numel
p.data = p._fp16_shard.view(
torch.float8_e4m3fn
if fp8_dtype_forward == DType.kFloat8E4M3
else torch.float8_e5m2
)
else:
p._fp16_shard.copy_(
# If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

@torch.no_grad()
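As a closing note on the last hunk: each rank quantizes only its own fp32 shard into the fp8 buffer using the owning TE module's current scale, so the all-gather moves one byte per element instead of two. Below is a simplified, self-contained sketch of that idea in plain PyTorch, not the PR's code: there is no TransformerEngine, `scale` stands in for the per-tensor value kept in fp8_meta["scaling_fwd"], amax tracking is skipped entirely, and torch.distributed is assumed to be initialized.

import torch
import torch.distributed as dist

def fp8_all_gather_shard(fp32_shard: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize the local shard to fp8 and all-gather the fp8 bytes."""
    world_size = dist.get_world_size()

    # Quantize with the current scale (cast_to_fp8 would also update amax;
    # this sketch does not track amax at all).
    fp8_shard = (fp32_shard * scale).to(torch.float8_e4m3fn)

    # NCCL moves raw bytes, so communicate a uint8 view of the fp8 data.
    out = torch.empty(fp32_shard.numel() * world_size, dtype=torch.uint8,
                      device=fp32_shard.device)
    dist.all_gather_into_tensor(out, fp8_shard.view(torch.uint8))

    # Reinterpret the gathered bytes as fp8; dequantization (multiplying by
    # scale_inv) is left to the consuming GEMM, as TransformerEngine does.
    return out.view(torch.float8_e4m3fn)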