diff --git a/main_training.py b/main_training.py
index 02fa2919..6ddf28a6 100644
--- a/main_training.py
+++ b/main_training.py
@@ -133,14 +133,24 @@ def main(**kwargs):
             if m.reversible and not m.tie_weights:
                 params_1d.append(m.head.weight)
         elif isinstance(m, MultiHeadAttention):
-            params_2d += [m.dense.weight,] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)]
+            params_2d += [
+                m.dense.weight,
+            ] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)]
         elif isinstance(m, GatedLinearUnit):
             params_2d += [m.wg1_fused.weight, m.w2.weight]
     optimizer = optim.AdamW(
         [
             {"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr},
-            {"params": params_1d, "lr": cfg.learning_rate * llama_config.mup_1d_lr / llama_config.emb_dim**.5},
-            {"params": params_2d, "lr": cfg.learning_rate * llama_config.mup_2d_lr / llama_config.emb_dim},
+            {
+                "params": params_1d,
+                "lr": cfg.learning_rate
+                * llama_config.mup_1d_lr
+                / llama_config.emb_dim**0.5,
+            },
+            {
+                "params": params_2d,
+                "lr": cfg.learning_rate * llama_config.mup_2d_lr / llama_config.emb_dim,
+            },
         ],
         betas=(0.9, 0.95),
         weight_decay=0.1,
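
For context, the block being reformatted sets up muP-style per-group learning rates: 0d (embedding-like) parameters keep the base rate, 1d parameters are scaled by 1/sqrt(emb_dim), and 2d (matrix) parameters by 1/emb_dim. Below is a minimal standalone sketch of that pattern; the toy model, `emb_dim`, `base_lr`, and the multiplier values are placeholders for illustration, not the repo's actual model or configuration.

```python
import torch.nn as nn
import torch.optim as optim

emb_dim = 512                            # assumed model width
base_lr = 3e-4                           # assumed base learning rate
mup_0d_lr = mup_1d_lr = mup_2d_lr = 1.0  # assumed per-group multipliers

# Toy stand-in for the real transformer: an embedding plus one projection.
model = nn.Sequential(nn.Embedding(1000, emb_dim), nn.Linear(emb_dim, emb_dim))

params_0d = [model[0].weight]  # embedding-like params: base LR
params_1d = [model[1].bias]    # vector-like params (e.g. an untied head): LR / sqrt(d)
params_2d = [model[1].weight]  # matrix-like params (projections): LR / d

optimizer = optim.AdamW(
    [
        {"params": params_0d, "lr": base_lr * mup_0d_lr},
        {"params": params_1d, "lr": base_lr * mup_1d_lr / emb_dim**0.5},
        {"params": params_2d, "lr": base_lr * mup_2d_lr / emb_dim},
    ],
    betas=(0.9, 0.95),
    weight_decay=0.1,
)
```

The diff itself only re-wraps these lines (Black-style formatting and `**.5` -> `**0.5`); the effective learning rates are unchanged.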