[PyTorch] Bugfix for wgrad bulk overlap conflict when dgrad overlap is reduce-scatter #1341
base: main
LGTM, pending CI
/te-ci pytorch L0 L1
Force-pushed from ab9e05f to 2ca29de
/te-ci pytorch L0 L1
if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
# Loop over user configs and disable dgrad and wgrad bulk overlaps for every layer that has a
# reduce-scatter dgrad overlap.
ub_cfg = {} if ub_cfg is None else ub_cfg
local variable ub_cfg referenced before assignment. ub_cfg --> ub_cfgs?
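A minimal, standalone sketch of why the flagged line fails, assuming (as the comment suggests) that ub_cfgs is the function argument and ub_cfg is only ever assigned inside the function. The function names below are hypothetical and are not taken from the PR:

```python
def normalize_buggy(ub_cfgs=None):
    # ub_cfg is assigned in this function, so Python treats it as a local name.
    # Reading it on the right-hand side before any assignment raises
    # UnboundLocalError.
    ub_cfg = {} if ub_cfg is None else ub_cfg
    return ub_cfg


def normalize_fixed(ub_cfgs=None):
    # Normalize the argument itself, so later `name in ub_cfgs` checks
    # no longer need a separate `ub_cfgs is not None` guard.
    ub_cfgs = {} if ub_cfgs is None else ub_cfgs
    return ub_cfgs


try:
    normalize_buggy()
    raised = False
except UnboundLocalError:
    raised = True
```

With the argument normalized once at the top, a plain `name in ub_cfgs` test is safe even when the caller passes no config.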
methods[new_method].append(name)

ub_cfg[name] = final_cfg
ub_cfg --> ub_cfgs?
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
ub_cfg = get_default_config(name)
if ub_cfgs is not None and name in ub_cfgs:
if ub_cfgs is not None and name in ub_cfgs ==> if name in ub_cfgs
fp8_buf = (name in layers_all_gather_overlap) or (
final_cfg = get_default_config(name)
final_cfg.update(ub_cfgs[name])
final_cfg["fp8_buf"] = (name in layers_all_gather_overlap) or (
If I am using the 'bf16' dtype for training, and name in layers_all_gather_overlap is true for some pattern (e.g., fc2_dgrad), then final_cfg["fp8_buf"] will be set to True, which is unexpected.
Okay, line 361 already uses use_fp8 to guard the effectiveness of fp8_buf, but it would be better if we could advance this guard to where we set the default value of cfg['fp8_buf'], to avoid confusion.
We can't advance this guard because it's not only for all-gather overlaps. I'll run through the logic to explain.
use_fp8 tells the initialization whether the user intends to invoke comm+GEMM overlap under the te.fp8_autocast() context (i.e. with FP8 GEMM inputs).
fp8_buf tells the initialization whether the communication buffer should be allocated as FP8.
For all-gather overlaps, the buffer has to match the GEMM input type, so it will always be allocated in FP8 when use_fp8 == True (i.e. GEMM inputs are FP8) regardless of what fp8_buf is set to in the user's layer configuration. In other words, the user does not get a choice here.
For reduce-scatter overlaps, the GEMM output has to match the buffer type, which can be either FP8 or BF16 when GEMM inputs are FP8. In this scenario, setting fp8_buf = True means that we communicate FP8 data between devices, and then fuse the BF16 upcast into the sum-reduce.
Advancing this guard to the default config options means that the user is denied the option to set fp8_buf for reduce-scatter overlaps, and that RS overlaps always communicate BF16 data, which is not always the optimal choice.
On a side note, the name in method["pipeline"] part of this is outdated and needs to be removed because we now support optional FP8 GEMM outputs/buffers in all reduce-scatter overlaps, not just collective/pipeline methods.
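The rules described in this comment can be summarized in a small truth-table style sketch. The helper name comm_buffer_is_fp8 and its is_reduce_scatter argument are hypothetical and only illustrate the logic as explained above; this is not the Transformer Engine implementation:

```python
def comm_buffer_is_fp8(is_reduce_scatter: bool, use_fp8: bool, fp8_buf: bool) -> bool:
    """Hypothetical helper: should the communication buffer be allocated as FP8?"""
    if not use_fp8:
        # Assumption based on the discussion: without FP8 GEMM inputs,
        # the fp8_buf setting has no effect.
        return False
    if is_reduce_scatter:
        # Reduce-scatter: the GEMM output may be FP8 or BF16, so the
        # user's fp8_buf choice is honored.
        return fp8_buf
    # All-gather: the buffer must match the FP8 GEMM inputs,
    # regardless of what fp8_buf is set to.
    return True
```

Moving the use_fp8 check into the default config would collapse the reduce-scatter branch to a fixed value, which is the loss of choice the comment warns about.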
Force-pushed from aad8294 to d2d9938
…dgrad overlap is enabled
Signed-off-by: Alp Dener
for more information, see https://pre-commit.ci
Signed-off-by: Tim Moon
Signed-off-by: Alp Dener
…er logic
Signed-off-by: Alp Dener
Force-pushed from 1036168 to 1797f14
Description
When the Userbuffers config dictionary sets the overlap method to ring-exchange or pipeline for any *_dgrad layer, that layer's *_wgrad overlap needs to be disabled in order for the ub_overlap_rs_dgrad=True option of the related TE modules to function correctly.
This PR fixes a bug where the "*_wgrad" overlap was persisting in the Userbuffers configuration and the corresponding UB object was being initialized even when it was not needed.
Type of change
Changes
Please list the changes introduced in this PR:
*_wgrad overlap is now removed from the methods["bulk"] list when the same layer's *_dgrad overlap has its method set to either ring-exchange or pipeline.
add_ub(name, **ub_cfg) is now only called if name is in the original user-provided ub_cfg. This avoids creating UB objects with default configs that may conflict with the user's intended TP overlap use.
Checklist: