Skip to content

Commit

Permalink
Add some minor speedups
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelstevens committed Nov 8, 2024
1 parent 84e29c3 commit 20cf631
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
15 changes: 8 additions & 7 deletions saev/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __getitem__(self, i: int) -> Example:
# Select layer's cls token.
act = img_act[self.cfg.layer, 0, :]
return self.Example(
torch.from_numpy(act), torch.tensor(i), torch.tensor(-1)
torch.from_numpy(act.copy()), torch.tensor(i), torch.tensor(-1)
)
case ("cls", "meanpool"):
img_act = self.get_img_patches(i)
Expand All @@ -251,7 +251,7 @@ def __getitem__(self, i: int) -> Example:
# Meanpool over the layers
act = cls_act.mean(axis=0)
return self.Example(
torch.from_numpy(act), torch.tensor(i), torch.tensor(-1)
torch.from_numpy(act.copy()), torch.tensor(i), torch.tensor(-1)
)
case ("meanpool", int()):
img_act = self.get_img_patches(i)
Expand All @@ -260,7 +260,7 @@ def __getitem__(self, i: int) -> Example:
# Meanpool over the patches
act = layer_act.mean(axis=0)
return self.Example(
torch.from_numpy(act), torch.tensor(i), torch.tensor(-1)
torch.from_numpy(act.copy()), torch.tensor(i), torch.tensor(-1)
)
case ("meanpool", "meanpool"):
img_act = self.get_img_patches(i)
Expand All @@ -269,7 +269,7 @@ def __getitem__(self, i: int) -> Example:
# Meanpool over the layers and patches
act = act.mean(axis=(0, 1))
return self.Example(
torch.from_numpy(act), torch.tensor(i), torch.tensor(-1)
torch.from_numpy(act.copy()), torch.tensor(i), torch.tensor(-1)
)
case ("patches", int()):
n_imgs_per_shard = (
Expand Down Expand Up @@ -299,9 +299,9 @@ def __getitem__(self, i: int) -> Example:
act = acts[
pos // self.metadata.n_patches_per_img,
pos % self.metadata.n_patches_per_img,
].copy()
]
return self.Example(
torch.from_numpy(act),
torch.from_numpy(act.copy()),
# What image is this?
torch.tensor(i // self.metadata.n_patches_per_img),
torch.tensor(i % self.metadata.n_patches_per_img),
Expand Down Expand Up @@ -331,7 +331,8 @@ def get_img_patches(
self.metadata.d_vit,
)
acts = np.memmap(acts_fpath, mode="c", dtype=np.float32, shape=shape)
return acts[pos].copy()
# Note that this is not yet copied!
return acts[pos]

def __len__(self) -> int:
"""
Expand Down
4 changes: 2 additions & 2 deletions saev/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def get_sae_batches(
perm = np.random.default_rng(seed=cfg.seed).permutation(len(dataset))
perm = perm[: cfg.reinit_size]

examples, _ = dataset[perm]
examples, _, _ = zip(*[dataset[p.item()] for p in perm])

return examples
return torch.stack(examples)


@beartype.beartype
Expand Down
4 changes: 2 additions & 2 deletions saev/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def train(cfg: config.Train) -> str:

global_step, n_patches_seen = 0, 0

for vit_acts, _ in helpers.progress(dataloader, every=cfg.log_every):
vit_acts = vit_acts.to(cfg.device)
for vit_acts, _, _ in helpers.progress(dataloader, every=cfg.log_every):
vit_acts = vit_acts.to(cfg.device, non_blocking=True)
# Make sure the W_dec is still zero-norm
sae.set_decoder_norm_to_unit_norm()

Expand Down

0 comments on commit 20cf631

Please sign in to comment.