Commit

linting
daviswer committed Jul 22, 2024
1 parent 4dd3998 commit 1491706
Showing 1 changed file with 13 additions and 3 deletions.
main_training.py: 16 changes (13 additions & 3 deletions)
@@ -133,14 +133,24 @@ def main(**kwargs):
             if m.reversible and not m.tie_weights:
                 params_1d.append(m.head.weight)
         elif isinstance(m, MultiHeadAttention):
-            params_2d += [m.dense.weight,] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)]
+            params_2d += [
+                m.dense.weight,
+            ] + [m_.weight for m_ in m.in_proj.modules() if isinstance(m_, nn.Linear)]
         elif isinstance(m, GatedLinearUnit):
             params_2d += [m.wg1_fused.weight, m.w2.weight]
     optimizer = optim.AdamW(
         [
             {"params": params_0d, "lr": cfg.learning_rate * llama_config.mup_0d_lr},
-            {"params": params_1d, "lr": cfg.learning_rate * llama_config.mup_1d_lr / llama_config.emb_dim**.5},
-            {"params": params_2d, "lr": cfg.learning_rate * llama_config.mup_2d_lr / llama_config.emb_dim},
+            {
+                "params": params_1d,
+                "lr": cfg.learning_rate
+                * llama_config.mup_1d_lr
+                / llama_config.emb_dim**0.5,
+            },
+            {
+                "params": params_2d,
+                "lr": cfg.learning_rate * llama_config.mup_2d_lr / llama_config.emb_dim,
+            },
         ],
         betas=(0.9, 0.95),
         weight_decay=0.1,
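For context, the grouping above feeds a muP-style width-dependent learning-rate rule into AdamW: 0d parameters use the base rate times mup_0d_lr, 1d parameters are additionally divided by emb_dim**0.5, and 2d (matrix-like) parameters by emb_dim. The sketch below reproduces that pattern in isolation; the toy model, the module-to-bucket assignment, and the concrete values of emb_dim, the base learning rate, and the mup_* multipliers are illustrative assumptions, not the repository's actual configuration.

import torch.nn as nn
import torch.optim as optim

# Illustrative values only; the real ones come from cfg / llama_config.
emb_dim = 512
base_lr = 3e-4
mup_0d_lr, mup_1d_lr, mup_2d_lr = 1.0, 1.0, 1.0

# Toy model standing in for the real transformer; the assignment of module
# types to the 0d/1d/2d buckets here is an assumption for illustration.
model = nn.Sequential(
    nn.Embedding(1000, emb_dim),  # treated here as a 1d-scaled group
    nn.Linear(emb_dim, emb_dim),  # weight is a 2d (matrix) parameter
    nn.LayerNorm(emb_dim),        # gains/biases treated as 0d parameters
)

params_0d, params_1d, params_2d = [], [], []
for m in model.modules():
    if isinstance(m, nn.LayerNorm):
        params_0d += list(m.parameters())
    elif isinstance(m, nn.Embedding):
        params_1d.append(m.weight)
    elif isinstance(m, nn.Linear):
        params_2d.append(m.weight)
        if m.bias is not None:
            params_0d.append(m.bias)

# Per-group rates mirror the diff: 1d groups scale as 1/sqrt(emb_dim),
# 2d groups as 1/emb_dim, 0d groups only by their multiplier.
optimizer = optim.AdamW(
    [
        {"params": params_0d, "lr": base_lr * mup_0d_lr},
        {"params": params_1d, "lr": base_lr * mup_1d_lr / emb_dim**0.5},
        {"params": params_2d, "lr": base_lr * mup_2d_lr / emb_dim},
    ],
    betas=(0.9, 0.95),
    weight_decay=0.1,
)

Written this way, doubling emb_dim halves the effective rate of the 2d groups and shrinks the 1d groups' rate by a factor of sqrt(2), while the 0d groups keep the base rate, which is the behavior the divisors in the commit encode.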
