Megatron-LM style Sequence Parallel #1257

Merged: 28 commits, Aug 23, 2024

Commits (28)
ad4f0a4
first draft (shape errors occurring)
haileyschoelkopf Jul 30, 2024
dc4c99b
training works (but poor convergence)
haileyschoelkopf Jul 31, 2024
3ccd3ba
debugging progress: current commit works if we do regular TP via impl…
haileyschoelkopf Aug 1, 2024
73aa0aa
Update NeoXArgs docs automatically
invalid-email-address Aug 8, 2024
92ed0cc
push most recent code (updated mark_norms fn, back to 'real' sequence…
haileyschoelkopf Aug 8, 2024
c93c1b4
Merge branch '812-megatron-seq-parallel' of https://github.com/Eleuth…
haileyschoelkopf Aug 8, 2024
9c1e7b9
Update NeoXArgs docs automatically
invalid-email-address Aug 8, 2024
651e24e
Fix LayerNorm all reduce gradient hook
bclyang Aug 12, 2024
9a43318
Sum instead of average for LayerNorm gradient all reduce
bclyang Aug 14, 2024
c0561d6
Update NeoXArgs docs automatically
invalid-email-address Aug 14, 2024
9945910
Merge pull request #1259 from EleutherAI/fix-ln-hooks
haileyschoelkopf Aug 14, 2024
2c5dc5a
Update NeoXArgs docs automatically
invalid-email-address Aug 14, 2024
9d883de
Fix gather and reduce scatter ops on sequence dimension
bclyang Aug 17, 2024
28a5a62
Fix sequence parallel with tied weight embeddings
bclyang Aug 18, 2024
5427d9d
Update NeoXArgs docs automatically
invalid-email-address Aug 18, 2024
b0d9398
Merge pull request #1260 from EleutherAI/fix-seq-dim-reducegatherdactter
haileyschoelkopf Aug 19, 2024
8f26029
cleanup pass + add MoE arguments.py guard
haileyschoelkopf Aug 19, 2024
d9db749
pre-commit and clean up comments
Quentin-Anthony Aug 19, 2024
aafbbce
remove vestigial debug code
Quentin-Anthony Aug 19, 2024
ba682e7
remove unused debugging code
haileyschoelkopf Aug 19, 2024
8455de7
remove dummy test config
haileyschoelkopf Aug 19, 2024
9ce982e
Merge branch '812-megatron-seq-parallel' of https://github.com/Eleuth…
haileyschoelkopf Aug 19, 2024
ab11a6a
update fp32_allreduce to handle fp16 ; don't cast to fp32 for gathers
haileyschoelkopf Aug 19, 2024
f26b886
run linter on the rest of the files
haileyschoelkopf Aug 19, 2024
8e7400f
Improve performance of sequence parallel gather, scatter, and reduce
bclyang Aug 22, 2024
53d0ae8
Add comment
bclyang Aug 22, 2024
05f5cec
Update NeoXArgs docs automatically
invalid-email-address Aug 22, 2024
1661db6
Merge pull request #1263 from EleutherAI/improve-seq-parallel-perf
haileyschoelkopf Aug 22, 2024
12 changes: 11 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 455446c
Default = 53d0ae8

current git hash of repository

@@ -1056,6 +1056,16 @@ Parallelism Arguments



- **sequence_parallel**: bool

Default = False

flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198)
(Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1.
**Set by user, in contrast to neox_args.is_pipe_parallel.**



- **expert_interval**: int

Default = 2
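As a quick orientation for the new `sequence_parallel` argument documented above, here is a hypothetical excerpt of the parallelism settings a run might use. It is written as a Python dict purely for illustration (real NeoX configs are YAML files whose keys mirror NeoXArgs), and the sizes are made up.

```python
# Illustrative only: hypothetical parallelism settings for a run that enables
# Megatron-style sequence parallelism. Real configs are YAML; keys mirror NeoXArgs.
parallelism_settings = {
    "model_parallel_size": 4,   # tensor (model) parallel degree; sequence shards live on this group
    "pipe_parallel_size": 1,
    "sequence_parallel": True,  # no effect when model_parallel_size is 1 (see docstring above)
}
```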
5 changes: 4 additions & 1 deletion megatron/model/__init__.py
@@ -16,5 +16,8 @@
# limitations under the License.

from .gpt2_model import GPT2ModelPipe
from .utils import get_params_for_weight_decay_optimization
from .utils import (
get_params_for_weight_decay_optimization,
mark_norms_for_sequence_parallel_grad_sync,
)
from .word_embeddings import SoftEmbedding
5 changes: 4 additions & 1 deletion megatron/model/gpt2_model.py
@@ -308,7 +308,10 @@ def _logits_helper(embedding, lm_output):
)

logits = parallel_lm_logits(
lm_output, embedding.word_embeddings_weight, self.parallel_output
lm_output,
embedding.word_embeddings_weight,
self.parallel_output,
seq_parallel=self.neox_args.sequence_parallel,
)
return logits

29 changes: 26 additions & 3 deletions megatron/model/transformer.py
@@ -254,6 +254,7 @@ def __init__(
gather_output=not parallel_output,
skip_bias_add=False,
mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here
seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1 rather than dim=0
)

# else:
@@ -1024,7 +1025,14 @@ def __init__(
self.moe_type = neox_args.moe_type

if self.gpt_j_residual:
self.reduce = mpu.mappings.reduce_from_model_parallel_region
# GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers.
# the reduction we use is a simple allreduce for pure Tensor Parallel,
# but needs to be a reduce-scatter when using Megatron-style Sequence Parallel (LN sharding.)
self.reduce = (
mpu.mappings.reduce_from_model_parallel_region
if not neox_args.sequence_parallel
else mpu.mappings.reduce_scatter_to_sequence_parallel_region
)

# Self attention.
self.attention = ParallelSelfAttention(
@@ -1339,10 +1347,25 @@ def forward(self, args):
return self.norm(args)


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
def parallel_lm_logits(
input_,
word_embeddings_weight,
parallel_output,
seq_parallel=False,
seq_dim=1,
bias=None,
):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = mpu.copy_to_model_parallel_region(input_)
if seq_parallel:
# if using Sequence Parallelism, our logits are sharded along the sequence dimension.
# gather them here. (backward pass: reduce-scatter)
input_parallel = mpu.gather_from_sequence_parallel_region(
input_, seq_dim=seq_dim
)
else:
# Set up backprop all-reduce.
input_parallel = mpu.copy_to_model_parallel_region(input_)

# Matrix multiply.
if bias is None:
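To make the communication pattern described in the comments above concrete (all-gather along the sequence dimension in the forward pass, reduce-scatter in the backward pass), here is a minimal standalone sketch of such an autograd function. The class and helper names here are hypothetical; the real op is `mpu.gather_from_sequence_parallel_region` in `megatron/mpu/mappings.py`, which runs over the model-parallel group rather than the default process group and may differ in detail. The sketch assumes `torch.distributed` is initialized and that the sequence length divides evenly by the world size.

```python
import torch
import torch.distributed as dist


class _GatherFromSequenceParallel(torch.autograd.Function):
    """Sketch: all-gather sequence shards in forward, reduce-scatter grads in backward."""

    @staticmethod
    def forward(ctx, input_, seq_dim):
        ctx.seq_dim = seq_dim
        world_size = dist.get_world_size()  # real impl: model-parallel group size
        # Each rank holds a shard such as [s/p, b, h] (or [b, s/p, h] when seq_dim=1);
        # gather the full sequence so the following matmul sees every token.
        shards = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(shards, input_.contiguous())
        return torch.cat(shards, dim=seq_dim)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size()
        # Split the full-sequence gradient back into per-rank shards and reduce-scatter,
        # so each rank receives the summed gradient for its own sequence shard only.
        shards = [
            s.contiguous()
            for s in torch.chunk(grad_output, world_size, dim=ctx.seq_dim)
        ]
        grad_input = torch.empty_like(shards[0])
        dist.reduce_scatter(grad_input, shards)
        return grad_input, None  # no gradient for seq_dim


def gather_from_sequence_parallel(input_, seq_dim=0):
    return _GatherFromSequenceParallel.apply(input_, seq_dim)
```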
56 changes: 46 additions & 10 deletions megatron/model/utils.py
@@ -18,8 +18,8 @@
"""Utilities for models."""

import torch
from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm
from megatron.model.fused_softmax import SoftmaxFusionTypes
from megatron import mpu
from types import GeneratorType
import torch.distributed as dist

@@ -35,15 +35,9 @@ def get_params_for_weight_decay_optimization(module, neox_args):
"name": "no_weight_decay_params",
}
for module_ in module.modules():
if any(
[
isinstance(module_, LayerNorm),
isinstance(module_, RMSNorm),
isinstance(module_, ScaleNorm),
]
) or (
neox_args.weight_decay == 0.0
): # also include all parameters here if no weight decay is being done
# don't apply weight decay to any "...Norm" modules.
if "norm" in type(module_).__name__.lower() or neox_args.weight_decay == 0.0:
# also include all parameters here if no weight decay is being done
no_weight_decay_params["params"].extend(
[p for p in list(module_._parameters.values()) if p is not None]
)
@@ -359,3 +353,45 @@ def get_fusion_type(neox_args):
elif neox_args.scaled_masked_softmax_fusion:
fusion_type = SoftmaxFusionTypes.general
return fusion_type


def reduce_weight_grads_from_model_parallel_region(input_):
"""A hook that can be applied to any weight tensor via .register_hook().
Allreduces grads for e.g. LN weights across the model parallel group.
Needed to keep LNs in sync, despite them getting diff data -> diff gradients when using sequence parallel.
"""
# Bypass the function if no TP -> no comm needed.
if mpu.get_model_parallel_world_size() == 1:
return input_

# Bf16 convert
dt = input_.dtype
if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
input_ = input_.float()

# All-reduce.
torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group())

# Bf16 convert
if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
input_ = input_.bfloat16()

return input_


def mark_norms_for_sequence_parallel_grad_sync(module, neox_args):
"""Iterate through the modules in our model, and for any "...Norm" classnames,
register a hook on each of that module's parameters which will allreduce norms' weights' grads across
the model (sequence) parallel region.
"""

if not neox_args.sequence_parallel:
# if we aren't using sequence parallelism, this is a no-op
return

for module_ in module.modules():
if "norm" in type(module_).__name__.lower():
# this is a norm, we want to allreduce its weight grads across sequence parallel region
for name, param in module_.named_parameters():
if param.requires_grad:
param.register_hook(reduce_weight_grads_from_model_parallel_region)
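A short, single-process illustration of what the hook above does mechanically: every parameter of a "...Norm" module gets a gradient hook, and during backward that hook is where the cross-rank all-reduce runs. Per the commits in this PR ("Fix LayerNorm all reduce gradient hook", "Sum instead of average for LayerNorm gradient all reduce"), the reduction is a sum: each model-parallel rank computes norm gradients from only its own sequence shard, and summing reconstructs the full-sequence gradient so the replicated norm weights stay in sync. The snippet below is a sketch with the collective stubbed out so it runs without `torch.distributed`; the stub name is made up.

```python
import torch
import torch.nn as nn


def _sum_norm_grads_across_model_parallel_ranks(grad):
    # Stand-in for reduce_weight_grads_from_model_parallel_region: the real hook calls
    # torch.distributed.all_reduce(grad, group=mpu.get_model_parallel_group()), i.e. a sum.
    return grad


model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# Mirror mark_norms_for_sequence_parallel_grad_sync: hook every parameter of any
# module whose class name contains "Norm".
for module_ in model.modules():
    if "norm" in type(module_).__name__.lower():
        for _, param in module_.named_parameters():
            if param.requires_grad:
                param.register_hook(_sum_norm_grads_across_model_parallel_ranks)

model(torch.randn(4, 8)).sum().backward()  # hooks fire here, once per norm parameter
```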
10 changes: 10 additions & 0 deletions megatron/model/word_embeddings.py
@@ -50,6 +50,11 @@ def __init__(
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes

self.sequence_parallel = (
neox_args.sequence_parallel
) # if we are using sequence parallelism, then we'll want to scatter our inputs across the seqlen dim across TP ranks

self.use_mup = neox_args.use_mup
self.mup_embedding_mult = neox_args.mup_embedding_mult
self.mup_rp_embedding_mult = neox_args.mup_rp_embedding_mult
@@ -159,6 +164,11 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):
with torch.no_grad():
embeddings.mul_(self.mup_embedding_mult)

if self.sequence_parallel:
# TODO: megatron-lm does dropout using the scattered embs. This would save a tiny bit of time, perhaps?
# Not a priority since we don't often use dropout
embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)

return embeddings


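For the scatter used above, the dual communication pattern applies: the forward pass keeps only this rank's slice of the sequence dimension, and the backward pass all-gathers the gradient shards back into a full-sequence gradient. Below is a minimal standalone sketch under the same caveats as before (hypothetical names, default process group instead of the model-parallel group, sequence length divisible by world size); the real op is `mpu.scatter_to_sequence_parallel_region`.

```python
import torch
import torch.distributed as dist


class _ScatterToSequenceParallel(torch.autograd.Function):
    """Sketch: keep the local sequence shard in forward, all-gather grads in backward."""

    @staticmethod
    def forward(ctx, input_):
        world_size = dist.get_world_size()  # real impl: model-parallel group size
        rank = dist.get_rank()
        # Embeddings arrive replicated as [s, b, h]; keep only the local [s/p, b, h] slice.
        return torch.chunk(input_, world_size, dim=0)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size()
        # Reassemble the full-sequence gradient by gathering every rank's shard.
        shards = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(shards, grad_output.contiguous())
        return torch.cat(shards, dim=0)
```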
3 changes: 3 additions & 0 deletions megatron/mpu/__init__.py
@@ -47,6 +47,9 @@
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .mappings import scatter_to_sequence_parallel_region

from .random import checkpoint
from .random import get_cuda_rng_tracker
39 changes: 36 additions & 3 deletions megatron/mpu/layers.py
@@ -33,6 +33,8 @@
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import VocabUtility
@@ -416,6 +418,7 @@ def __init__(
MOE=False,
MoE_mp_size=1,
mup_rescale_parameters=False,
seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout.
):
super(ColumnParallelLinear, self).__init__()

@@ -427,6 +430,10 @@ def __init__(
world_size = MoE_mp_size if MOE else get_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add

self.sequence_parallel = neox_args.sequence_parallel
self.seq_dim = seq_dim

self.init_method = init_method
self.stride = stride
self.mup_rescale_parameters = mup_rescale_parameters
@@ -551,14 +558,29 @@ def set_parallel_output(self, value: bool):
def forward(self, input_):
if self.use_mup and self.mup_rescale_parameters:
input_ /= self.width_mult()
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)

if self.sequence_parallel:
input_parallel = input_
else:
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.

if self.sequence_parallel:
# do an AG in the fwd pass, RS in bwd pass.
# gather / scatter portion happens across the sequence dim (self.seq_dim)--
# almost always is [s, b, h] and so dim 0, but for lm_head ParallelLinear it is seq_dim=1 and [b, s, h]
input_parallel = gather_from_sequence_parallel_region(
input_parallel, seq_dim=self.seq_dim
)

bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
assert (
not self.sequence_parallel
), "sequence_parallel=True and gather_output=True are incompatible!"
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
@@ -623,6 +645,12 @@ def __init__(
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
self.parallel_output = parallel_output

self.sequence_parallel = neox_args.sequence_parallel
assert not (
self.sequence_parallel and not self.input_is_parallel
), "Cannot have self.input_is_parallel=False and self.sequence_parallel=True."

self.init_method = init_method
self.stride = stride
self.keep_master_weight_for_test = keep_master_weight_for_test
@@ -748,7 +776,12 @@ def forward(self, input_):
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
if not self.parallel_output:
if self.sequence_parallel and not self.parallel_output:
# do an RS in the fwd pass, AG in bwd pass.
# skip in the gpt-j parallel sublayer case (self.parallel_output=True)
# (user responsible for calling reduce-scatter)
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
elif not self.parallel_output:
output_ = reduce_from_model_parallel_region(output_parallel)
else:
output_ = output_parallel
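Putting the ColumnParallelLinear and RowParallelLinear changes together, the forward data flow for one sequence-parallel MLP sublayer is: per-rank input [s/p, b, h], all-gather over the sequence dim to [s, b, h], column-parallel matmul to [s, b, 4h/p], row-parallel matmul to a partial [s, b, h], then reduce-scatter back to [s/p, b, h]. The sketch below simulates that flow on a single process, carrying the "ranks" as a Python list so the collectives become cat / sum + chunk; all sizes are made up, the activation between the two matmuls is omitted, and no real communication happens.

```python
import torch

# Hypothetical sizes: tp = 2 model-parallel ranks, s = 8 tokens, b = 2, h = 16.
tp, s, b, h = 2, 8, 2, 16

# Each rank starts from its own sequence shard of the sublayer input: [s/p, b, h].
x_shards = [torch.randn(s // tp, b, h) for _ in range(tp)]

# ColumnParallelLinear with sequence_parallel=True: all-gather along the sequence dim
# in the forward pass, so every rank sees the full [s, b, h] before its local matmul.
x_full = torch.cat(x_shards, dim=0)
w_col = [torch.randn(4 * h // tp, h) for _ in range(tp)]  # output features split across ranks
y = [x_full @ w.t() for w in w_col]                       # per rank: [s, b, 4h/p]

# RowParallelLinear: each rank multiplies by its input-feature shard, producing a
# partial result that still has to be summed across ranks.
w_row = [torch.randn(h, 4 * h // tp) for _ in range(tp)]
z_partial = [y[r] @ w_row[r].t() for r in range(tp)]      # per rank: [s, b, h], partial sums

# With sequence_parallel=True the sum is a reduce-scatter rather than an all-reduce:
# each rank keeps only its own sequence shard of the summed output.
z_full = torch.stack(z_partial).sum(dim=0)
z_shards = list(torch.chunk(z_full, tp, dim=0))           # per rank: [s/p, b, h]

print(x_shards[0].shape, y[0].shape, z_shards[0].shape)
# torch.Size([4, 2, 16]) torch.Size([8, 2, 32]) torch.Size([4, 2, 16])
```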