diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 1e67685ed..413138597 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 455446c + Default = 53d0ae8 current git hash of repository @@ -1056,6 +1056,16 @@ Parallelism Arguments +- **sequence_parallel**: bool + + Default = False + + flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198) + (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1. + **Set by user, in contrast to neox_args.is_pipe_parallel.** + + + - **expert_interval**: int Default = 2 diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index 619b4c33d..23be28936 100755 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -16,5 +16,8 @@ # limitations under the License. from .gpt2_model import GPT2ModelPipe -from .utils import get_params_for_weight_decay_optimization +from .utils import ( + get_params_for_weight_decay_optimization, + mark_norms_for_sequence_parallel_grad_sync, +) from .word_embeddings import SoftEmbedding diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 9e643874a..7899048db 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -308,7 +308,10 @@ def _logits_helper(embedding, lm_output): ) logits = parallel_lm_logits( - lm_output, embedding.word_embeddings_weight, self.parallel_output + lm_output, + embedding.word_embeddings_weight, + self.parallel_output, + seq_parallel=self.neox_args.sequence_parallel, ) return logits diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c154b09f4..62e7d3a9c 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -254,6 +254,7 @@ def __init__( gather_output=not parallel_output, skip_bias_add=False, mup_rescale_parameters=is_last_layer, # rescale params only called if neox_args.use_mup = True, despite it not being included here + seq_dim=1, # important: must mark that this layer receives shape [b, s, h] not [s, b, h] and so Seq. Parallel comms must gather along dim=1 rather than dim=0 ) # else: @@ -1024,7 +1025,14 @@ def __init__( self.moe_type = neox_args.moe_type if self.gpt_j_residual: - self.reduce = mpu.mappings.reduce_from_model_parallel_region + # GPT-J style layers allow us to defer the reduction of results across TP ranks until the end of the two sublayers. + # the reduction we use is a simple allreduce for pure Tensor Parallel, + # but needs to be a reduce-scatter when using Megatron-style Sequence Parallel (LN sharding.) + self.reduce = ( + mpu.mappings.reduce_from_model_parallel_region + if not neox_args.sequence_parallel + else mpu.mappings.reduce_scatter_to_sequence_parallel_region + ) # Self attention. self.attention = ParallelSelfAttention( @@ -1339,10 +1347,25 @@ def forward(self, args): return self.norm(args) -def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): +def parallel_lm_logits( + input_, + word_embeddings_weight, + parallel_output, + seq_parallel=False, + seq_dim=1, + bias=None, +): """LM logits using word embedding weights.""" # Parallel logits. - input_parallel = mpu.copy_to_model_parallel_region(input_) + if seq_parallel: + # if using Sequence Parallelism, our logits are sharded along the sequence dimension. + # gather them here. (backward pass: reduce-scatter) + input_parallel = mpu.gather_from_sequence_parallel_region( + input_, seq_dim=seq_dim + ) + else: + # Set up backprop all-reduce. + input_parallel = mpu.copy_to_model_parallel_region(input_) # Matrix multiply. if bias is None: diff --git a/megatron/model/utils.py b/megatron/model/utils.py index c3da2ce8b..97b409c1d 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -18,8 +18,8 @@ """Utilities for models.""" import torch -from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm from megatron.model.fused_softmax import SoftmaxFusionTypes +from megatron import mpu from types import GeneratorType import torch.distributed as dist @@ -35,15 +35,9 @@ def get_params_for_weight_decay_optimization(module, neox_args): "name": "no_weight_decay_params", } for module_ in module.modules(): - if any( - [ - isinstance(module_, LayerNorm), - isinstance(module_, RMSNorm), - isinstance(module_, ScaleNorm), - ] - ) or ( - neox_args.weight_decay == 0.0 - ): # also include all parameters here if no weight decay is being done + # apply weight decay to any "...Norm" modules. + if "norm" in type(module_).__name__.lower() or neox_args.weight_decay == 0.0: + # also include all parameters here if no weight decay is being done no_weight_decay_params["params"].extend( [p for p in list(module_._parameters.values()) if p is not None] ) @@ -359,3 +353,45 @@ def get_fusion_type(neox_args): elif neox_args.scaled_masked_softmax_fusion: fusion_type = SoftmaxFusionTypes.general return fusion_type + + +def reduce_weight_grads_from_model_parallel_region(input_): + """A hook that can be applied to any weight tensor via .register_hook(). + Allreduces grads for e.g. LN weights across the model parallel group. + Needed to keep LNs in sync, despite them getting diff data -> diff gradients when using sequence parallel. + """ + # Bypass the function if no TP -> no comm needed. + if mpu.get_model_parallel_world_size() == 1: + return input_ + + # Bf16 convert + dt = input_.dtype + if dt == torch.bfloat16 and mpu.get_fp32_allreduce(): + input_ = input_.float() + + # All-reduce. + torch.distributed.all_reduce(input_, group=mpu.get_model_parallel_group()) + + # Bf16 convert + if dt == torch.bfloat16 and mpu.get_fp32_allreduce(): + input_ = input_.bfloat16() + + return input_ + + +def mark_norms_for_sequence_parallel_grad_sync(module, neox_args): + """Iterate through the modules in our model, and for any "...Norm" classnames, + register a hook on each of that module's parameters which will allreduce norms' weights' grads across + the model (sequence) parallel region. + """ + + if not neox_args.sequence_parallel: + # if we aren't using sequence parallelism, this is a no-op + return + + for module_ in module.modules(): + if "norm" in type(module_).__name__.lower(): + # this is a norm, we want to allreduce its weight grads across sequence parallel region + for name, param in module_.named_parameters(): + if param.requires_grad: + param.register_hook(reduce_weight_grads_from_model_parallel_region) diff --git a/megatron/model/word_embeddings.py b/megatron/model/word_embeddings.py index f7372bc55..ce3c1117e 100644 --- a/megatron/model/word_embeddings.py +++ b/megatron/model/word_embeddings.py @@ -50,6 +50,11 @@ def __init__( self.hidden_size = hidden_size self.init_method = init_method self.num_tokentypes = num_tokentypes + + self.sequence_parallel = ( + neox_args.sequence_parallel + ) # if we are using sequence parallelism, then we'll want to scatter our inputs across the seqlen dim across TP ranks + self.use_mup = neox_args.use_mup self.mup_embedding_mult = neox_args.mup_embedding_mult self.mup_rp_embedding_mult = neox_args.mup_rp_embedding_mult @@ -159,6 +164,11 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): with torch.no_grad(): embeddings.mul_(self.mup_embedding_mult) + if self.sequence_parallel: + # TODO: megatron-lm does dropout using the scattered embs. This would save a tiny bit of time, perhaps? + # Not a priority since we don't often use dropout + embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) + return embeddings diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 2365507d9..780fb33e8 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -47,6 +47,9 @@ from .mappings import gather_from_model_parallel_region from .mappings import reduce_from_model_parallel_region from .mappings import scatter_to_model_parallel_region +from .mappings import reduce_scatter_to_sequence_parallel_region +from .mappings import gather_from_sequence_parallel_region +from .mappings import scatter_to_sequence_parallel_region from .random import checkpoint from .random import get_cuda_rng_tracker diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 0d14806ac..d59edab94 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -33,6 +33,8 @@ from .mappings import gather_from_model_parallel_region from .mappings import reduce_from_model_parallel_region from .mappings import scatter_to_model_parallel_region +from .mappings import reduce_scatter_to_sequence_parallel_region +from .mappings import gather_from_sequence_parallel_region from .random import get_cuda_rng_tracker from .utils import divide from .utils import VocabUtility @@ -416,6 +418,7 @@ def __init__( MOE=False, MoE_mp_size=1, mup_rescale_parameters=False, + seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout. ): super(ColumnParallelLinear, self).__init__() @@ -427,6 +430,10 @@ def __init__( world_size = MoE_mp_size if MOE else get_model_parallel_world_size() self.output_size_per_partition = divide(output_size, world_size) self.skip_bias_add = skip_bias_add + + self.sequence_parallel = neox_args.sequence_parallel + self.seq_dim = seq_dim + self.init_method = init_method self.stride = stride self.mup_rescale_parameters = mup_rescale_parameters @@ -551,14 +558,29 @@ def set_parallel_output(self, value: bool): def forward(self, input_): if self.use_mup and self.mup_rescale_parameters: input_ /= self.width_mult() - # Set up backprop all-reduce. - input_parallel = copy_to_model_parallel_region(input_) + + if self.sequence_parallel: + input_parallel = input_ + else: + # Set up backprop all-reduce. + input_parallel = copy_to_model_parallel_region(input_) # Matrix multiply. + if self.sequence_parallel: + # do an AG in the fwd pass, RS in bwd pass. + # gather / scatter portion happens across the sequence dim (self.seq_dim)-- + # almost always is [s, b, h] and so dim 0, but for lm_head ParallelLinear it is seq_dim=1 and [b, s, h] + input_parallel = gather_from_sequence_parallel_region( + input_parallel, seq_dim=self.seq_dim + ) + bias = self.bias if not self.skip_bias_add else None output_parallel = F.linear(input_parallel, self.weight, bias) if self.gather_output: # All-gather across the partitions. + assert ( + not self.sequence_parallel + ), "sequence_parallel=True and gather_output=True are incompatible!" output = gather_from_model_parallel_region(output_parallel) else: output = output_parallel @@ -623,6 +645,12 @@ def __init__( self.input_size_per_partition = divide(input_size, world_size) self.skip_bias_add = skip_bias_add self.parallel_output = parallel_output + + self.sequence_parallel = neox_args.sequence_parallel + assert not ( + self.sequence_parallel and not self.input_is_parallel + ), "Cannot have self.input_is_parallel=False and self.sequence_parallel=True." + self.init_method = init_method self.stride = stride self.keep_master_weight_for_test = keep_master_weight_for_test @@ -748,7 +776,12 @@ def forward(self, input_): # Matrix multiply. output_parallel = F.linear(input_parallel, self.weight) # All-reduce across all the partitions. - if not self.parallel_output: + if self.sequence_parallel and not self.parallel_output: + # do an RS in the fwd pass, AG in bwd pass. + # skip in the gpt-j parallel sublayer case (self.parallel_output=True) + # (user responsible for calling reduce-scatter) + output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) + elif not self.parallel_output: output_ = reduce_from_model_parallel_region(output_parallel) else: output_ = output_parallel diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py index 535fe6255..f11d9e6ab 100644 --- a/megatron/mpu/mappings.py +++ b/megatron/mpu/mappings.py @@ -23,7 +23,7 @@ get_model_parallel_rank, get_fp32_allreduce, ) -from .utils import split_tensor_along_last_dim +from .utils import split_tensor_along_last_dim, split_tensor_along_any_dim def _reduce(input_): @@ -33,17 +33,17 @@ def _reduce(input_): if get_model_parallel_world_size() == 1: return input_ - # Bf16 convert + # upcast to fp32 if using fp32 allreduce dt = input_.dtype - if dt == torch.bfloat16 and get_fp32_allreduce(): + if get_fp32_allreduce(): input_ = input_.float() # All-reduce. torch.distributed.all_reduce(input_, group=get_model_parallel_group()) - # Bf16 convert - if dt == torch.bfloat16 and get_fp32_allreduce(): - input_ = input_.bfloat16() + # reconvert to original Bf16/Fp16 dtype + if get_fp32_allreduce(): + input_ = input_.to(dt) return input_ @@ -75,11 +75,6 @@ def _gather(input_): if world_size == 1: return input_ - # Bf16 convert - dt = input_.dtype - if dt == torch.bfloat16 and get_fp32_allreduce(): - input_ = input_.float() - # Size and dimension. last_dim = input_.dim() - 1 rank = get_model_parallel_rank() @@ -91,9 +86,100 @@ def _gather(input_): # Note: torch.cat already creates a contiguous tensor. output = torch.cat(tensor_list, dim=last_dim).contiguous() - # Bf16 convert - if dt == torch.bfloat16 and get_fp32_allreduce(): - output = output.bfloat16() + return output + + +def _reduce_scatter_along_seq_dim(input_, seq_dim): + """Reduce-scatter the input tensor across model parallel group, scattering across sequence dim.""" + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # upcast to fp32 if using fp32 allreduce + dt = input_.dtype + if get_fp32_allreduce(): + input_ = input_.float() + + dim_size = list(input_.size()) + assert ( + isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0 + ), "seq_dim must be a valid tensor dim" + assert dim_size[seq_dim] % world_size == 0 + + if seq_dim == 0: + # reduce_scatter_tensor is faster but only works correctly on dimension 0 + dim_size[seq_dim] = dim_size[seq_dim] // world_size + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=get_model_parallel_group() + ) + else: + tensor_list = list( + torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim) + ) + output = torch.empty_like(tensor_list[0]) + torch.distributed.reduce_scatter(output, tensor_list) + + # reconvert to original Bf16/Fp16 dtype + if get_fp32_allreduce(): + output = output.to(dt) + + return output + + +def _gather_along_seq_dim(input_, seq_dim): + """Gather tensors and concatinate along the (manually-specified) sequence dimension.""" + + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + assert ( + isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0 + ), "seq_dim must be a valid tensor dim" + dim_size[seq_dim] = dim_size[seq_dim] * world_size + + if seq_dim == 0: + # reduce_gather_tensor is faster but only works correctly on dimension 0 + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed.all_gather_into_tensor( + output, input_.contiguous(), group=get_model_parallel_group() + ) + else: + input_ = input_.contiguous() + rank = get_model_parallel_rank() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather( + tensor_list, input_, group=get_model_parallel_group() + ) + output = torch.cat(tensor_list, dim=seq_dim) + + return output + + +def _split_along_seq_dim(input_, seq_dim): + """Split the tensor along the sequence dimension (as manually selected) and keep the + corresponding slice.""" + + world_size = get_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along second dimension. + input_list = split_tensor_along_any_dim(input_, world_size, seq_dim) + + # Note: torch.split does not create contiguous tensors by default. + rank = get_model_parallel_rank() + output = input_list[rank].contiguous() return output @@ -162,6 +248,65 @@ def backward(ctx, grad_output): return _split(grad_output) +class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): + """Reduce-Scatter across sequence parallel region (same as model parallel region.) + Note: same region as model parallel region + """ + + @staticmethod + def symbolic(graph, input_, seq_dim): + return _reduce_scatter_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def forward(ctx, input_, seq_dim): + ctx.seq_dim = seq_dim + return _reduce_scatter_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def backward(ctx, grad_output): + seq_dim = ctx.seq_dim + return _gather_along_seq_dim(grad_output, seq_dim=seq_dim), None + + +class _GatherFromSequenceParallelRegion(torch.autograd.Function): + """All-Gather across sequence parallel region (same region as model parallel region.)""" + + @staticmethod + def symbolic(graph, input_, seq_dim): + return _gather_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def forward(ctx, input_, seq_dim): + ctx.seq_dim = seq_dim + return _gather_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def backward(ctx, grad_output): + seq_dim = ctx.seq_dim + return _reduce_scatter_along_seq_dim(grad_output, seq_dim=seq_dim), None + + +class _ScatterToSequenceParallelRegion(torch.autograd.Function): + """Scatter (split) sequence length across sequence parallel region (=> same region as model parallel.)""" + + @staticmethod + def symbolic(graph, input_, seq_dim): + return _split_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def forward(ctx, input_, seq_dim): + ctx.seq_dim = seq_dim + return _split_along_seq_dim(input_, seq_dim=seq_dim) + + @staticmethod + def backward(ctx, grad_output): + seq_dim = ctx.seq_dim + return ( + _gather_along_seq_dim(grad_output, seq_dim=seq_dim), + None, + ) + + # ----------------- # Helper functions. # ----------------- @@ -181,3 +326,17 @@ def scatter_to_model_parallel_region(input_): def gather_from_model_parallel_region(input_): return _GatherFromModelParallelRegion.apply(input_) + + +def reduce_scatter_to_sequence_parallel_region(input_, seq_dim=0): + return _ReduceScatterToSequenceParallelRegion.apply(input_, seq_dim) + + +def gather_from_sequence_parallel_region(input_, seq_dim=0): + return _GatherFromSequenceParallelRegion.apply(input_, seq_dim) + + +def scatter_to_sequence_parallel_region( + input_, seq_dim=1 +): # use this fn in scattering input embeds across TP ranks. There, shape of inps is [b, s, h] instead of the usual [s, b, h] + return _ScatterToSequenceParallelRegion.apply(input_, seq_dim) diff --git a/megatron/mpu/utils.py b/megatron/mpu/utils.py index 13941dc29..1f97e0e76 100644 --- a/megatron/mpu/utils.py +++ b/megatron/mpu/utils.py @@ -53,6 +53,28 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks= return tensor_list +def split_tensor_along_any_dim( + tensor, num_partitions, seq_dim, contiguous_split_chunks=False +): + """Split a tensor along a user-specified dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + seq_dim: dimension along which to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + seq_dim_size = divide(tensor.size()[seq_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, seq_dim_size, dim=seq_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the first and last index of the vocabulary belonging to the `rank` diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 054689eda..6a84df6c7 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1038,6 +1038,10 @@ def calculate_derived(self): assert self.zero_optimization["stage"] != 3, "MoE not compatible with zero3" assert self.mlp_type == "regular", "MoE not compatible with LLaMA" + assert ( + self.sequence_parallel is False + ), "MoE not compatible with Sequence Parallel" + # Attention config if self.attention_config is None: self.update_value("attention_config", [[["global"], self.num_layers]]) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index febefb3c2..7993f785f 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -85,6 +85,13 @@ class NeoXArgsParallelism(NeoXArgsTemplate): according to pipeline parallel size. """ + sequence_parallel: bool = False + """ + flag to determine whether Megatron-style Sequence Parallelism (https://arxiv.org/abs/2205.05198) + (Layernorm inputs and activations are sharded across model parallel group) will be used. Has no effect when model_parallel_size is 1. + **Set by user, in contrast to neox_args.is_pipe_parallel.** + """ + expert_interval: int = 2 """ Have one MoE layer every expert_interval layers diff --git a/megatron/training.py b/megatron/training.py index 3265680c5..ce59b242a 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -43,6 +43,7 @@ GPT2ModelPipe, SoftEmbedding, get_params_for_weight_decay_optimization, + mark_norms_for_sequence_parallel_grad_sync, ) from megatron.checkpointing import load_checkpoint, save_checkpoint from megatron.data.data_utils import build_train_valid_test_data_iterators @@ -765,6 +766,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): # config_params=neox_args.deepspeed_config, mpu=mpu if not neox_args.is_pipe_parallel else None, ) + mark_norms_for_sequence_parallel_grad_sync(model, neox_args) if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks": # We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE. model.has_moe_layers = True @@ -891,6 +893,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) and neox_args.iteration <= neox_args.profile_step_stop ): torch.cuda.nvtx.range_push(f"Optimizer step") + timers("optimizer").start() if neox_args.deepspeed: model.step()