Edits to bypass basic errors (#3)
* Edits to bypass compile errors

* Remove `.cpu()` in model weight stats logging
vinamarora8 authored Oct 11, 2024
Parent: 3a90a97 · Commit: ef946e7
Showing 3 changed files with 7 additions and 4 deletions.
examples/poyo/train.py (3 additions, 0 deletions)
@@ -74,6 +74,8 @@ def run_training(cfg: DictConfig):
         include=OmegaConf.to_container(cfg.dataset), # converts to native list[dicts]
         transform=transform,
     )
+    train_dataset.disable_data_leakage_check()
+
     # In Lightning, testing only happens once, at the end of training. To get the
     # intended behavior, we need to specify a validation set instead.
     val_tokenizer = copy.copy(tokenizer)

@@ -84,6 +86,7 @@ def run_training(cfg: DictConfig):
         include=OmegaConf.to_container(cfg.dataset), # converts to native list[dicts]
         transform=val_tokenizer,
     )
+    val_dataset.disable_data_leakage_check()
 
     if not cfg.finetune:
         # Register units and sessions
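
For context on the comment in the hunk above: in Lightning, trainer.test() runs only once after fit() finishes, while a val_dataloader is evaluated every training epoch, which is why the script builds a second dataset with a copied tokenizer and treats it as validation. A rough, self-contained sketch of that wiring follows; the class and argument names are illustrative, not taken from this repository.

# Illustrative sketch only (hypothetical names): per-epoch evaluation in Lightning
# goes through val_dataloader; trainer.test() would run just once after training.
import lightning as L  # or `pytorch_lightning`, depending on the installed package
from torch.utils.data import DataLoader

class ExampleDataModule(L.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, batch_size=64):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # Trainer.fit() runs this loader at the end of every training epoch.
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
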
torch_brain/models/capoyo.py (1 addition, 1 deletion)
@@ -7,7 +7,7 @@
 from einops import rearrange, repeat
 
 from brainsets.taxonomy import DecoderSpec, Decoder
-from brainsets.taxonomy.mice import Cre_line, Depth_classes
+from brainsets.taxonomy.mice import Cre_line
 from torch_brain.nn import (
     Embedding,
     InfiniteVocabEmbedding,
torch_brain/utils/train_wrapper.py (3 additions, 3 deletions)
@@ -100,12 +100,12 @@ def training_step(self, data, data_idx):

     def on_train_epoch_end(self):
         for tag, value in self.model.named_parameters():
-            self.log(f"weights/mean_{tag}", value.cpu().mean(), sync_dist=True)
-            self.log(f"weights/std_{tag}", value.cpu().std(), sync_dist=True)
+            self.log(f"weights/mean_{tag}", value.mean(), sync_dist=True)
+            self.log(f"weights/std_{tag}", value.std(), sync_dist=True)
             if value.grad is not None:
                 self.log(
                     f"grads/mean_{tag}",
-                    value.grad.cpu().mean(),
+                    value.grad.mean(),
                     sync_dist=True,
                 )
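
A note on the `.cpu()` removal above: Lightning's self.log accepts tensors on any device and handles the host transfer and cross-rank reduction itself (with sync_dist=True), so computing .mean()/.std() directly on the parameter avoids a blocking device-to-host copy per tensor. Below is a minimal sketch of the same pattern outside this repository's wrapper; the class name and stand-in model are hypothetical.

# Minimal sketch (not the repo's TrainWrapper): log per-parameter statistics
# straight from the device tensors; self.log handles the transfer and reduction.
import lightning as L  # or `pytorch_lightning`, depending on the installed package
import torch.nn as nn

class WeightStatsLogger(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(16, 4)  # stand-in model

    def on_train_epoch_end(self):
        for tag, value in self.model.named_parameters():
            # No .cpu(): .mean()/.std() run on whatever device holds the weights.
            self.log(f"weights/mean_{tag}", value.mean(), sync_dist=True)
            self.log(f"weights/std_{tag}", value.std(), sync_dist=True)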
