[4] Clean train config by using hydra config overriding #6

Merged: 19 commits, Oct 14, 2024
43 changes: 43 additions & 0 deletions examples/poyo/configs/base.yaml
@@ -0,0 +1,43 @@
# Base config for training poyo_plus

data_root: ./data/processed/
log_dir: ./logs
seed: 42

batch_size: 128
eval_batch_size: null # if null, will use batch_size
num_workers: 4

epochs: 1000 # if -1, will use steps
steps: -1 # if -1, will use epochs. Epochs take precedence.
eval_epochs: 1

optim:
  base_lr: 3.125e-5 # scaled linearly by batch size
  weight_decay: 1e-4
  lr_decay_start: 0.5 # fraction of epochs before starting LR decay

wandb:
  enable: true
  entity: null
  project: poyo
  run_name: null
  log_model: false

backend_config: gpu_fp32
precision: 32
nodes: 1
gpus: 1

# Where to resume/finetune from. Could be null (yaml for None, meaning train from
# scratch) or a fully qualified path to the .ckpt file.
ckpt_path: null

# Finetuning configuration:
finetune: false
# Num of epochs to freeze perceiver network while finetuning
# -1 => Keep frozen, i.e. perform Unit-identification
# 0 => Train everything
# >0 => Only train unit/session embeddings for first few epochs,
# and then train everything
freeze_perceiver_until_epoch: 0
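
With everything shared now living in base.yaml, any value can be overridden per experiment or from the command line. Below is a minimal sketch of inspecting the composed config with Hydra's compose API, assuming Hydra 1.2+ and that it is run from examples/poyo; the override values are arbitrary examples, not taken from this PR.

from hydra import compose, initialize

# Compose base.yaml and apply a couple of dot-path overrides, mirroring what
# `python train.py optim.base_lr=1e-4 wandb.enable=false gpus=2` would do on the CLI.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="base",
        overrides=["optim.base_lr=1e-4", "wandb.enable=false", "gpus=2"],
    )

print(cfg.optim.base_lr)  # 0.0001 (overridden)
print(cfg.batch_size)     # 128 (inherited from base.yaml)
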
39 changes: 12 additions & 27 deletions examples/poyo/configs/train_allen_neuropixels.yaml
@@ -1,8 +1,8 @@
# Path: configs/train.yaml
defaults:
- _self_
- base.yaml
- model: poyo_single_session.yaml
- dataset: allen_neuropixels.yaml
- _self_

hydra:
  searchpath:
@@ -18,31 +18,16 @@ train_transforms:
crop_len: 1.0

data_root: /kirby/processed/allen_all/
seed: 42

batch_size: 128
eval_epochs: 10
steps: 0 # Note we either specify epochs or steps, not both.
epochs: 1000
base_lr: 1.5625e-5
weight_decay: 0.0001
# Fraction of epochs to warmup for.
pct_start: 0.5
num_workers: 4
log_dir: ./logs
name: allen_neuropixels
backend_config: gpu_fp16
precision: bf16-mixed
nodes: 1
gpus: 1
# Where to resume/finetune from. Could be null (yaml for None, meaning train from
# scratch) or a fully qualified path to the .ckpt file.
ckpt_path: null

# Finetuning configuration:
finetune: false
# Num of epochs to freeze perceiver network while finetuning
# -1 => Keep frozen, i.e. perform Unit-identification
# 0 => Train everything
# >0 => Only train unit/session embeddings for first few epochs,
# and then train everything
freeze_perceiver_until_epoch: 0
optim:
  base_lr: 1.5625e-5
  weight_decay: 0.0001

wandb:
  run_name: allen_neuropixels

backend_config: gpu_fp16
precision: bf16-mixed
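
Because _self_ now comes after base.yaml in the defaults list, the values written in this experiment file are merged on top of the base config: optim.base_lr, backend_config, precision and wandb.run_name here win, everything else is inherited. A rough illustration of that precedence using OmegaConf directly; the merge below is an approximation of Hydra's composition, and the file path is an assumption.

from omegaconf import OmegaConf

base = OmegaConf.load("examples/poyo/configs/base.yaml")
# Keys the experiment file sets itself; everything else falls through to base.
experiment = OmegaConf.create(
    {
        "optim": {"base_lr": 1.5625e-5},
        "backend_config": "gpu_fp16",
        "precision": "bf16-mixed",
        "wandb": {"run_name": "allen_neuropixels"},
    }
)
merged = OmegaConf.merge(base, experiment)

print(merged.optim.base_lr)  # 1.5625e-05, from the experiment file
print(merged.batch_size)     # 128, inherited from base.yaml
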
36 changes: 8 additions & 28 deletions examples/poyo/configs/train_mc_maze_small.yaml
@@ -1,8 +1,8 @@
# Path: configs/train.yaml
defaults:
- _self_
- base.yaml
- model: poyo_single_session.yaml
- dataset: mc_maze_small.yaml
- _self_

train_transforms:
- _target_: kirby.transforms.UnitDropout
@@ -11,32 +11,12 @@ train_transforms:
mode_units: 300
peak: 4

data_root: ./data/processed/
seed: 42
batch_size: 128
eval_epochs: 10
epochs: 1000
steps: 0 # Note we either specify epochs or steps, not both.
base_lr: 1.5625e-5
weight_decay: 0.0001
# Fraction of epochs to warmup for.
pct_start: 0.5
num_workers: 4
log_dir: ./logs
name: mc_maze_small
backend_config: gpu_fp32
precision: 32
nodes: 1
gpus: 1
# Where to resume/finetune from. Could be null (yaml for None, meaning train from
# scratch) or a fully qualified path to the .ckpt file.
ckpt_path: null

# Finetuning configuration:
finetune: false
# Num of epochs to freeze perceiver network while finetuning
# -1 => Keep frozen, i.e. perform Unit-identification
# 0 => Train everything
# >0 => Only train unit/session embeddings for first few epochs,
# and then train everything
freeze_perceiver_until_epoch: 0
optim:
  base_lr: 1.5625e-5
  weight_decay: 0.0001

wandb:
  run_name: mc_maze_small
35 changes: 4 additions & 31 deletions examples/poyo/configs/train_poyo_mp.yaml
@@ -1,8 +1,8 @@
# Path: configs/train.yaml
defaults:
- _self_
- base.yaml
- model: poyo_single_session.yaml
- dataset: perich_miller_population_2018.yaml
- _self_

train_transforms:
- _target_: torch_brain.transforms.UnitDropout
@@ -11,32 +11,5 @@ train_transforms:
mode_units: 300
peak: 4

data_root: ./data/processed/
seed: 42
batch_size: 128
eval_epochs: 1
epochs: 1000
steps: 0 # Note we either specify epochs or steps, not both.
base_lr: 3.125e-5
weight_decay: 1e-4
# Fraction of epochs to warmup for.
pct_start: 0.5
num_workers: 4
log_dir: ./logs
name: poyo_mp_mini
backend_config: gpu_fp32
precision: 32
nodes: 1
gpus: 1
# Where to resume/finetune from. Could be null (yaml for None, meaning train from
# scratch) or a fully qualified path to the .ckpt file.
ckpt_path: null

# Finetuning configuration:
finetune: false
# Num of epochs to freeze perceiver network while finetuning
# -1 => Keep frozen, i.e. perform Unit-identification
# 0 => Train everything
# >0 => Only train unit/session embeddings for first few epochs,
# and then train everything
freeze_perceiver_until_epoch: 0
wandb:
  run_name: poyo_mp_mini
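
For context, these experiment configs are selected through Hydra at the entry point. The sketch below is an illustration of such an entry point, assuming Hydra 1.2+ and that it sits next to the configs directory; it is not the repo's actual train.py.

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="configs", config_name="train_poyo_mp", version_base=None)
def main(cfg: DictConfig) -> None:
    # After composition, base.yaml provides the defaults; the experiment file and
    # any command-line overrides (e.g. optim.base_lr=1e-4) win on conflicts.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
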
27 changes: 13 additions & 14 deletions examples/poyo/train.py
@@ -157,19 +157,17 @@ def run_training(cfg: DictConfig):
        val_dataset,
        sampler=val_sampler,
        collate_fn=collate,
        batch_size=cfg.get(
            "eval_batch_size", cfg.batch_size
        ), # Default to training batch size, but allow override in config.
        batch_size=cfg.eval_batch_size or cfg.batch_size,
        num_workers=2,
    )

# Update config with dynamic data
with open_dict(cfg):
cfg.steps_per_epoch = len(train_loader)

if cfg.epochs > 0 and cfg.steps == 0:
if cfg.epochs > 0:
cfg.epochs = cfg.epochs
elif cfg.steps > 0 and cfg.epochs == 0:
elif cfg.steps > 0:
cfg.epochs = cfg.steps // cfg.steps_per_epoch + 1
cfg.steps = 0
log.info(f"Setting epochs to {cfg.epochs} using cfg.steps = {cfg.steps}")
@@ -183,14 +181,15 @@ def run_training(cfg: DictConfig):
        dataset_config_dict=train_dataset.get_session_config_dict(),
    )

    wandb = lightning.pytorch.loggers.WandbLogger(
        save_dir=cfg.log_dir,
        entity=cfg.get("wandb_entity", None),
        name=cfg.name,
        project=cfg.get("wandb_project", "poyo"),
        log_model=cfg.get("wandb_log_model", False),
    )
    print(f"Wandb ID: {wandb.version}")
    wandb_logger = None
    if cfg.wandb.enable:
        wandb_logger = lightning.pytorch.loggers.WandbLogger(
            save_dir=cfg.log_dir,
            entity=cfg.wandb.entity,
            name=cfg.wandb.run_name,
            project=cfg.wandb.project,
            log_model=cfg.wandb.log_model,
        )

    callbacks = [
        ModelSummary(max_depth=2), # Displays the number of parameters in the model.
@@ -212,7 +211,7 @@ raise NotImplementedError("This functionality isn't properly implemented.")
raise NotImplementedError("This functionality isn't properly implemented.")

    trainer = lightning.Trainer(
        logger=wandb,
        logger=wandb_logger,
        default_root_dir=cfg.log_dir,
        check_val_every_n_epoch=cfg.eval_epochs,
        max_epochs=cfg.epochs,
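
Two small behaviours in the updated run_training are worth spelling out: eval_batch_size falls back to the training batch size when it is null, and a positive steps count is converted to an epoch count once steps_per_epoch is known. The snippet below is a standalone paraphrase of that logic, not code from the PR.

def resolve_eval_batch_size(eval_batch_size, batch_size):
    # null (None) in base.yaml means "reuse the training batch size";
    # note that `or` also treats 0 as unset.
    return eval_batch_size or batch_size


def resolve_epochs(epochs, steps, steps_per_epoch):
    # Epochs take precedence; steps are only consulted when epochs <= 0.
    if epochs > 0:
        return epochs
    if steps > 0:
        return steps // steps_per_epoch + 1
    raise ValueError("Either epochs or steps must be positive.")


assert resolve_eval_batch_size(None, 128) == 128
assert resolve_epochs(-1, 10_000, 391) == 26  # 10000 // 391 + 1
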
6 changes: 3 additions & 3 deletions examples/poyo/wrapper.py
@@ -40,20 +40,20 @@ def __init__(
        self.save_hyperparameters(OmegaConf.to_container(cfg))

    def configure_optimizers(self):
        max_lr = self.cfg.base_lr * self.cfg.batch_size # linear scaling rule
        max_lr = self.cfg.optim.base_lr * self.cfg.batch_size # linear scaling rule

        optimizer = Lamb(
            self.model.parameters(),
            lr=max_lr,
            weight_decay=self.cfg.weight_decay,
            weight_decay=self.cfg.optim.weight_decay,
        )

        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=max_lr,
            epochs=self.cfg.epochs,
            steps_per_epoch=self.cfg.steps_per_epoch,
            pct_start=self.cfg.pct_start,
            pct_start=self.cfg.optim.lr_decay_start,
            anneal_strategy="cos",
            div_factor=1,
        )
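
The optimizer now reads its hyperparameters from the optim block, and the effective peak learning rate is still the linear scaling rule: base_lr times batch size. A quick check with the base.yaml numbers (my arithmetic, not part of the diff):

base_lr = 3.125e-5    # optim.base_lr in base.yaml
batch_size = 128
max_lr = base_lr * batch_size
print(max_lr)         # 0.004

# With pct_start (optim.lr_decay_start) = 0.5, OneCycleLR spends the first half
# of training ramping up to max_lr and the second half annealing back down ("cos").
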