-
Notifications
You must be signed in to change notification settings - Fork 282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FP8 AllGather Support in Fairscale #1185
base: ngoyal_changes_for_pp_fp8_jiecaoyu_debug
Are you sure you want to change the base?
FP8 AllGather Support in Fairscale #1185
Conversation
Co-authored-by: Naman Goyal <[email protected]>
This commit works with a 4 GPU run on SMALL model with FSDP and PP enabled.
- Clean up flatten and non_flatten parameter generation logic. - Avoid checking `main_grad` attribute all equal to zeros.
- Cleans up amax and scale update logic. Amax and scale should be done for both weights and parameters. So it should be done at forward of each microbatch. - Consolidate `cast_params` and `all_gather` stream.
Co-authored-by: Naman Goyal <[email protected]>
This commit works with a 4 GPU run on SMALL model with FSDP and PP enabled.
- Clean up flatten and non_flatten parameter generation logic. - Avoid checking `main_grad` attribute all equal to zeros.
- Cleans up amax and scale update logic. Amax and scale should be done for both weights and parameters. So it should be done at forward of each microbatch. - Consolidate `cast_params` and `all_gather` stream.
…kresearch/fairscale into shikaili_fp8_allgather_no_pp_fix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @levendlee for the great work! I left some comments for my own learning.
and all(_is_te_module_with_weights(info[1]) for info in p._param_infos)) | ||
if fused_wgard_accumulation: | ||
if getattr(p, "main_grad", None) is None: | ||
p.main_grad = torch.empty_like(p, dtype=torch.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For my understanding, why empty_like
instead of zeros_like
?
if params is None: | ||
params = self.params | ||
with torch.cuda.stream(self._streams["fp32_to_fp16"]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious why did you use the "all_gather"
stream instead of the "fp32_to_fp16"
stream?
@@ -2087,6 +2179,9 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None: | |||
|
|||
self.has_full_params = False | |||
|
|||
if self.fp8_all_gather: | |||
self._update_amax_and_scale_fwd(is_first_microbatch_fwd=is_first_microbatch_fwd) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For my understanding, is there a reason that this is not done together with _cast_params_for_all_gather
? (For example, could this call be delayed a few lines to below where _cast_params_for_all_gather
is called?)
|
||
|
||
|
||
@torch.no_grad() | ||
def _rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_gather = True) -> Optional[List[Tuple[torch.Tensor, bool]]]: | ||
def _rebuild_full_params( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For fp8_all_gather=True
, what happens when this method is called without the TE autocast context?
@@ -1448,16 +1505,22 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: | |||
|
|||
# All-gather full parameters. This will also transfer FP32 parameters to | |||
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``). | |||
self._rebuild_full_params() | |||
self.module.has_unflatten_views = getattr(self.module, "has_unflatten_views", False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need this?
What does this PR do?
Fixes # (issue).
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.