When I test a model with full recomputation, the forward all-gather communication is not overlapped, because `is_grad_enabled` is False during the forward pass under full recomputation. I see the following code in the `_LayerNormLinear` class:
```python
if ub_overlap_ag:
    tp_world_size = get_distributed_world_size(tp_group)
    if tp_world_size == 1 or (not is_grad_enabled):
        ub_overlap_ag = False
if ub_overlap_ag:
    dim_size = list(inputmat.size())
    dim_size[0] = dim_size[0] * tp_world_size
    ub_obj_lnout = get_ub(ub_name + "_fprop")
    if return_layernorm_output:
        # First prepare LN output in higher precision,
        # which will be later copied to a FP8 UB
        ln_out = torch.empty_like(inputmat, memory_format=torch.contiguous_format)
    else:
        ln_out = ub_obj_lnout.get_ubuf_output(0)
```
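The gating logic above can be sketched as a standalone predicate (a hedged sketch for illustration, not Transformer Engine's actual API; the function name `should_overlap_ag` is my own):

```python
def should_overlap_ag(ub_overlap_ag: bool, tp_world_size: int, is_grad_enabled: bool) -> bool:
    # Mirrors the condition quoted from _LayerNormLinear: the all-gather
    # overlap is disabled when tensor parallelism is not active
    # (tp_world_size == 1) or when grad mode is off -- which is exactly
    # the situation in the no-grad forward pass of full recomputation.
    if ub_overlap_ag and (tp_world_size == 1 or not is_grad_enabled):
        ub_overlap_ag = False
    return ub_overlap_ag

# With gradients enabled and TP active, overlap stays on:
print(should_overlap_ag(True, 8, True))   # True
# A no-grad forward (e.g. the first pass under full recomputation) turns it off:
print(should_overlap_ag(True, 8, False))  # False
```

This makes the question concrete: under full recomputation the initial forward runs with grad mode disabled, so the predicate returns False and the overlap path is never taken.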
Why is `ub_overlap_ag` set to False under the `(not is_grad_enabled)` condition?