Fp8 all gather hack #1136

Open · wants to merge 3 commits into base: ngoyal_added_zero2_shard_modelparams_multiple_gpus from fp8_all_gather
Conversation

@jspark1105 commented on Sep 17, 2023

This is based on ngoyal_added_zero2_shard_modelparams_multiple_gpus and adds hacks to use fp8 all-gather with Nvidia's Transformer Engine (see the latest commit for the changes on top of the ngoyal_added_zero2_shard_modelparams_multiple_gpus branch).

This depends on Transformer Engine changes in https://github.com/facebookresearch/TransformerEngine/pull/20.
See https://github.com/fairinternal/xlformers/pull/1403 for an example of how to use it.
It also depends on PyTorch changes in pytorch/pytorch#109654.

To use fp8 all-gather, set compute_dtype=torch.float8_e4m3fn and mixed_precision=True (see the usage sketch below).
We separate out precision-critical parameters, such as the affine weights for norms, as non-flattened params and hard-code them to use bf16.
We update scale/scale_inv inside forward, before _rebuild_full_params calls _cast_fp32_param_shards_to_fp16, whereas the TE baseline updates scale/scale_inv in prepare_forward. This is because we need the fp8 quantization of the weights earlier, before the all-gather. (One could consider doing this in post-backward instead, but that is problematic: the backward amax update is only done after the backward of all layers has finished, which can be later than post-backward, so we would not be using the latest backward amax info for the scale/scale_inv update.)
We hard-code special handling for a couple of TransformerEngine layers (Linear, LayerNormLinear, and LayerNormMLP) in _cast_fp32_param_shards_to_fp16 so we can access their fp8 metadata and quantize with the right scales (TODO: we may want to extract this as user-customizable callback functions).
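For concreteness, here is a minimal usage sketch based on the description above. It assumes this hacked fairscale branch plus the Transformer Engine and PyTorch changes linked earlier, with torch.distributed already initialized; the model, layer sizes, and fp8 recipe are illustrative, not taken from this PR:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # TE modules whose weights will be all-gathered in fp8.
    model = torch.nn.Sequential(
        te.Linear(4096, 4096, bias=False),
        te.LayerNormMLP(4096, 16384),
    ).cuda()

    # compute_dtype=torch.float8_e4m3fn with mixed_precision=True enables the
    # fp8 all-gather path added by this PR; precision-critical params (e.g.
    # norm affine weights) stay non-flattened and are cast to bf16 internally.
    fsdp_model = FSDP(
        model,
        mixed_precision=True,
        compute_dtype=torch.float8_e4m3fn,
    )

    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        y = fsdp_model(x)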

CC @awgu @ngoyal2707 @vedanuj @jiecaoyu @yf225 @GD06

@facebook-github-bot added the CLA Signed label on Sep 17, 2023
@jspark1105 marked this pull request as ready for review on September 18, 2023
@jspark1105 changed the base branch from main to ngoyal_added_zero2_shard_modelparams_multiple_gpus on October 4, 2023
@jspark1105 (Author) commented:
Will merge the main_grad-related changes with #1142.

@jspark1105 force-pushed the fp8_all_gather branch 2 times, most recently from bd70153 to af3d2d7 on October 5, 2023
                # Cast grad to FP32.
                grad_or_main_grad.data = grad_or_main_grad.to(param.dtype)
            elif self._is_fp8_dtype():
                # Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
Member commented:
Currently this is not working with the latest FP8 wgrad?

@jspark1105 (Author) replied:
This is meant for future work, when we have fp8 reduce-scatter. I'll update the comment.

@@ -1393,7 +1447,11 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:

         # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
         # the conversion).
-        is_bf16 = self.compute_dtype == torch.bfloat16
+        is_bf16 = self.compute_dtype in [
Member commented:
nit: is_bf16_or_fp8
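The added check is truncated in the excerpt above; presumably it expands to include the fp8 dtype named in the PR description, roughly as in the sketch below (the helper name follows the reviewer's nit and is hypothetical; the exact dtype list is not visible in the diff):

    import torch

    def _is_bf16_or_fp8(compute_dtype: torch.dtype) -> bool:
        # Presumed expansion of the truncated check above; torch.float8_e4m3fn
        # is the fp8 compute_dtype from the PR description. Hypothetical helper,
        # not the actual diff.
        return compute_dtype in [torch.bfloat16, torch.float8_e4m3fn]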

@jspark1105 force-pushed the fp8_all_gather branch 2 times, most recently from b9b093b to a2b49d1 on October 7, 2023
@@ -2265,8 +2361,7 @@ def local_metadata_dict(self) -> Dict[str, Any]:
                 backing_param_name = m.module.flat_param_names[i]
                 names, shapes, numels = m.module.metadata(i)
             else:
-                assert len(m._param_name_groups[i]) == 1
-                backing_param_name = m._param_name_groups[i][0]
+                backing_param_name = m._param_name_groups[m._num_flatten_params][i - m._num_flatten_params]
@jspark1105 (Author) commented:
Need to make sure checkpointing works properly with this.

@jspark1105 force-pushed the fp8_all_gather branch 2 times, most recently from d92dc0f to 6a4d7f4 on October 15, 2023
@facebook-github-bot commented:
Hi @jspark1105! Thank you for your pull request. We require contributors to sign our Contributor License Agreement; you currently have a record in our system, but the CLA is no longer valid and will need to be resubmitted. Please sign at https://code.facebook.com/cla (if you are contributing on behalf of someone else, e.g. your employer, the corporate CLA may be needed). Once the CLA is signed, the pull request will be tagged CLA Signed; tagging may take up to 1 hour. If you have received this in error or have any questions, please contact us at [email protected].
