Commit

fix activation checkpointing and logging
jahatef committed Dec 18, 2024
1 parent bdb3658 commit 945ae22
Showing 4 changed files with 13 additions and 12 deletions.
13 changes: 7 additions & 6 deletions megatron/logging.py
@@ -97,14 +97,15 @@ def get_flops(neox_args, iter_time_s) -> float:
         num_heads = neox_args.num_attention_heads
 
         flops_per_iteration = (
-            batch_size
+            ckpt_activations_factor
+            * batch_size
             * seq_len
             * (
-                78 * hidden_size * hidden_size * num_layers
-                + 84 * hidden_size * num_layers
-                + 16 * hidden_size
-                + 12 * hidden_size * vocab_size
-                + 18 * hidden_size * hidden_size * num_layers / num_heads
+                26 * hidden_size * hidden_size * num_layers
+                + 928 * hidden_size * num_layers
+                + 8 * hidden_size
+                + 4 * hidden_size * vocab_size
+                + 6 * hidden_size * hidden_size * num_layers / num_heads
             )
         )
     elif "mamba" in neox_args.attention_config:
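The hunk above stops baking a recomputation multiplier into the per-term constants and instead scales the whole estimate by an explicit ckpt_activations_factor, so the RWKV branch of get_flops only charges for activation recomputation when checkpointing is actually enabled. A minimal sketch of the usual Megatron-style convention, assuming ckpt_activations_factor is defined earlier in get_flops (the surrounding names mirror the diff; this is not the verbatim file):

    # Sketch: forward costs roughly 1x the model FLOPs and backward roughly 2x;
    # full activation checkpointing re-runs the forward once more, so the
    # multiplier is 4 with checkpointing and 3 without.
    ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3

    flops_per_iteration = (
        ckpt_activations_factor
        * batch_size
        * seq_len
        * (
            26 * hidden_size * hidden_size * num_layers
            + 928 * hidden_size * num_layers
            + 8 * hidden_size
            + 4 * hidden_size * vocab_size
            + 6 * hidden_size * hidden_size * num_layers / num_heads
        )
    )
    # get_flops then converts this per-iteration count into a rate by dividing
    # by the measured iteration time (and the number of ranks).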
1 change: 1 addition & 0 deletions megatron/model/gpt2_model.py
@@ -140,6 +140,7 @@ def __init__(
"GMLPBlock",
"ParallelTransformerLayerPipe",
"ParallelMambaResidualLayerPipe",
"RWKVResidualLayerPipe"
],
)

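Adding "RWKVResidualLayerPipe" matters because this list names the layer classes the pipeline engine is allowed to recompute; a class missing from it is silently excluded from activation checkpointing. A hedged sketch of how such a list is typically wired into DeepSpeed's PipelineModule (the argument names follow DeepSpeed's public API; specs and pipe_parallel_size are illustrative placeholders, not taken from this file):

    # Sketch under the assumption that GPT2ModelPipe forwards this list to
    # deepspeed.pipe.PipelineModule, whose engine only checkpoints layers
    # whose class name appears in checkpointable_layers.
    from deepspeed.pipe import PipelineModule

    model = PipelineModule(
        layers=specs,                          # LayerSpec sequence built elsewhere
        num_stages=pipe_parallel_size,         # illustrative placeholder
        activation_checkpoint_interval=1,      # recompute every eligible layer
        checkpointable_layers=[
            "GMLPBlock",
            "ParallelTransformerLayerPipe",
            "ParallelMambaResidualLayerPipe",
            "RWKVResidualLayerPipe",           # the entry this commit adds
        ],
    )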
8 changes: 3 additions & 5 deletions megatron/model/rwkv/v6/rwkv.py
@@ -282,7 +282,7 @@ def forward(self, x):

         r, k, v, g, w = self.jit_func(x)
         if self.neox_args.rwkv_fla:
-            x, _ = RUN_FLA_CHUNK(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H))
+            x, _ = RUN_FLA_CHUNK(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H),chunk_size=256)
         else:
             x = RUN_CUDA_RWKV(B, T, C_tp, H, r, k, v, w, u=scatter_to_model_parallel_region(self.time_faaaa.view(-1)).view(H,C_tp//H))
 
@@ -418,11 +418,9 @@ def __init__(self, neox_args, init_method, layer_number):
                 ],
                 verbose=True,
                 extra_cuda_cflags=[
-                    "-res-usage",
-                    "--use_fast_math",
+                    "-ffast-math",
                     "-O3",
-                    "-Xptxas -O3",
-                    "--extra-device-vectorization",
+                    "-fvectorize",
                     f"-D_N_={self.neox_args.head_size}",
                     f"-D_T_={self.neox_args.seq_length}",
                 ],
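Two things change in the RWKV code: the fla (flash-linear-attention) path now passes an explicit chunk_size=256 to RUN_FLA_CHUNK, and the JIT-compiled WKV kernel build swaps -res-usage, --use_fast_math, -Xptxas -O3, and --extra-device-vectorization for -ffast-math and -fvectorize. A hedged sketch of the torch.utils.cpp_extension.load call these flags feed into (the module name and source paths are assumptions for illustration, not copied from the file):

    # Sketch of the JIT kernel build with the post-commit flag set.
    from torch.utils.cpp_extension import load

    wkv_cuda = load(
        name="wkv6",
        sources=[
            "megatron/model/rwkv/v6/cuda/wkv6_op.cpp",   # assumed path
            "megatron/model/rwkv/v6/cuda/wkv6_cuda.cu",  # assumed path
        ],
        verbose=True,
        extra_cuda_cflags=[
            "-ffast-math",                     # fast-math flag after this commit
            "-O3",
            "-fvectorize",
            f"-D_N_={neox_args.head_size}",    # head size baked in at compile time
            f"-D_T_={neox_args.seq_length}",   # max sequence length baked in
        ],
    )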
3 changes: 2 additions & 1 deletion megatron/mpu/mappings.py
@@ -334,7 +334,8 @@ def reduce_scatter_to_sequence_parallel_region(input_, seq_dim=0):
     return _ReduceScatterToSequenceParallelRegion.apply(input_, seq_dim)
 
 
-def gather_from_sequence_parallel_region(input_, seq_dim=0):
+def gather_from_sequence_parallel_region(input_: torch.Tensor, seq_dim: int = 0):
+#def gather_from_sequence_parallel_region(input_, seq_dim=0):
     return _GatherFromSequenceParallelRegion.apply(input_, seq_dim)
 
 
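The behavior here is unchanged; the commit only adds type hints to gather_from_sequence_parallel_region and leaves the old signature behind as a comment. A hedged usage sketch of the gather / reduce-scatter pair for sequence parallelism (it assumes the model-parallel groups are already initialized and that local_shard holds this rank's slice of the sequence):

    # Sketch only: requires torch.distributed and the mpu groups to be set up.
    from megatron.mpu.mappings import (
        gather_from_sequence_parallel_region,
        reduce_scatter_to_sequence_parallel_region,
    )

    # local_shard: [seq_len // tp_world_size, batch, hidden] on each rank.
    # gather reassembles the full sequence before a tensor-parallel block;
    # reduce-scatter sums the partial outputs across ranks and re-splits
    # them along the sequence dimension afterwards.
    full_seq = gather_from_sequence_parallel_region(local_shard, seq_dim=0)
    partial_out = some_tensor_parallel_block(full_seq)   # illustrative placeholder
    local_shard = reduce_scatter_to_sequence_parallel_region(partial_out, seq_dim=0)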
