Commit

add sweeping and sharded training

samuelstevens committed Oct 31, 2024
1 parent bd33ddd commit fcb77f7
Showing 10 changed files with 410 additions and 226 deletions.
7 changes: 7 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,7 @@
# Contributing

Contributions are welcome.
This document outlines some programming conventions that are not caught by automated tools.
## Variable Names

Variables referring to a filepath should be suffixed with `_fpath`.
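
For example (hypothetical names, purely to illustrate the convention):

import os

acts_dir = "/tmp/saev-acts"  # hypothetical directory

# Good: the `_fpath` suffix signals a filesystem path.
metadata_fpath = os.path.join(acts_dir, "metadata.json")

# Bad: `metadata` reads like parsed contents, not a path on disk.
metadata = os.path.join(acts_dir, "metadata.json")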
1 change: 1 addition & 0 deletions configs/baseline.toml
@@ -0,0 +1 @@
lr = [5e-5, 1e-4, 2e-4, 4e-4]
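
saev/sweep.py is not shown in this commit view, but judging from how `saev.sweep.expand` is called in main.py below, it expands each list-valued key into a grid of configs; a minimal sketch of that presumed behavior:

import itertools

def expand(sweep: dict) -> list[dict]:
    # Hypothetical reimplementation: treat every list value as a sweep axis
    # and yield one dict per point in the Cartesian product.
    keys = list(sweep)
    axes = [v if isinstance(v, list) else [v] for v in sweep.values()]
    return [dict(zip(keys, point)) for point in itertools.product(*axes)]

# The baseline.toml above would expand to four configs:
# [{'lr': 5e-05}, {'lr': 0.0001}, {'lr': 0.0002}, {'lr': 0.0004}]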
5 changes: 5 additions & 0 deletions logbook.md
@@ -245,4 +245,9 @@ Then it's 5TB.
Still too big.
But I can debug these processes on the lab servers.

With a ViT-B/32, saving the last 3 layers, ImageNet-1K is

1.2M x 3 x 50 x 768 x 4 bytes/float = 553GB
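
Sanity-checking that arithmetic (assuming 224x224 inputs for ViT-B/32, so 49 patches plus 1 CLS token per image, and 4 bytes per float32):

n_imgs, n_layers, d_vit = 1_200_000, 3, 768
tokens_per_img = (224 // 32) ** 2 + 1  # 49 patches + CLS = 50
total_bytes = n_imgs * n_layers * tokens_per_img * d_vit * 4
print(f"{total_bytes / 1e9:.0f} GB")  # 553 GB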

It seems that training is working well.
I can train on 100M patches in about 40m on an A6000, which is good: that's 10x more tokens than the 10M-patch run and only about 10x slower (40m vs 4m), so throughput scales linearly.
52 changes: 49 additions & 3 deletions main.py
@@ -1,4 +1,6 @@
import dataclasses
import logging
import tomllib
import typing

import tyro
@@ -7,17 +9,18 @@

log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)

logger = logging.getLogger("main")


def activations(
cfg: typing.Annotated[saev.ActivationsConfig, tyro.conf.arg(name="")],
):
"""
Save vision model activations for use later on.
Save ViT activations for use later on.
Args:
cfg: Configuration for dumping activations.
cfg: Configuration for activations.
"""
if not cfg.ssl:
logger.warning("Ignoring SSL certs. Try not to do this!")
@@ -32,6 +35,48 @@ def activations(
saev.activations.dump(cfg)


def sweep(cfg: typing.Annotated[saev.TrainConfig, tyro.conf.arg(name="")], sweep: str):
"""
Run a grid search over a set of hyperparameters.
Args:
cfg: Baseline config for training an SAE.
sweep: Path to .toml file defining the sweep parameters.
"""
import submitit
import saev.sweep
import saev.training

with open(sweep, "rb") as fd:
dcts = list(saev.sweep.expand(tomllib.load(fd)))
logger.info("Sweep has %d experiments.", len(dcts))

sweep_cfgs, errs = [], []
for dct in dcts:
try:
sweep_cfgs.append(dataclasses.replace(cfg, **dct))
except Exception as err:
errs.append(str(err))

if cfg.slurm:
executor = submitit.SlurmExecutor(folder=cfg.log_to)
executor.update_parameters(
time=30,
partition="debug",
gpus_per_node=1,
cpus_per_task=cfg.n_workers + 4,
stderr_to_stdout=True,
account=cfg.slurm_acct,
)
else:
executor = submitit.DebugExecutor(folder=cfg.log_to)

jobs = executor.map_array(saev.training.train, sweep_cfgs)
for i, result in enumerate(submitit.helpers.as_completed(jobs)):
exp_id = result.result()
logger.info("Finished task %s (%d/%d)", exp_id, i + 1, len(jobs))


def train(cfg: typing.Annotated[saev.TrainConfig, tyro.conf.arg(name="")]):
import submitit

@@ -106,7 +151,8 @@ def fn():
if __name__ == "__main__":
tyro.extras.subcommand_cli_from_dict({
"activations": activations,
"train": train,
"sweep": sweep,
# "train": train,
# "analysis": analysis,
# "webapp": webapp,
})
180 changes: 125 additions & 55 deletions saev/activations.py
@@ -2,11 +2,16 @@
To save lots of activations, we want to parallelize across many Slurm jobs and save multiple files rather than just one.
This module handles that additional complexity.
Conceptually, activations can be thought of in two ways:
1. A single [n_imgs x n_layers x (n_patches + 1), d_vit] tensor. This is a *dataset*.
2. Multiple [n_imgs_per_shard, n_layers, (n_patches + 1), d_vit] tensors. This is a set of *sharded activations*.
"""

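To make the two views concrete, here is a sketch with invented sizes showing how a flat row index maps onto shards:

# Hypothetical sizes, for illustration only.
n_imgs, n_layers, n_patches, d_vit = 1_000, 3, 49, 768

# View 1: one flat dataset of activation vectors,
# shape [n_imgs * n_layers * (n_patches + 1), d_vit].
n_rows = n_imgs * n_layers * (n_patches + 1)  # 150_000 rows

# View 2: the same rows split across fixed-size shard files.
# A global row index maps to a (shard, offset) pair:
n_rows_per_shard = 60_000
shard, offset = divmod(123_456, n_rows_per_shard)  # shard 2, offset 3_456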
import json
import dataclasses
import hashlib
import json
import logging
import os
import shutil
@@ -30,7 +35,7 @@ class RecordedVit(torch.nn.Module):
n_layers: int
n_patches: int
model: torch.nn.Module
_storage: Float[Tensor, "batch n_layers n_patches+1 dim"] | None
_storage: Float[Tensor, "batch n_layers all_patches dim"] | None

def __init__(self, cfg: config.Activations):
super().__init__()
@@ -61,8 +66,8 @@ def __init__(self, cfg: config.Activations):
assert w % pw == 0
assert h % ph == 0
self.n_patches = (w // pw) * (h // ph)
msg = f"ViT has {self.n_patches} patches; config has {cfg.n_patches} n_patches."
assert self.n_patches == cfg.n_patches, msg
msg = f"ViT has {self.n_patches} patches; config has {cfg.n_patches_per_img} n_patches."
assert self.n_patches == cfg.n_patches_per_img, msg
self.n_layers = cfg.n_layers

self.logger = logging.getLogger("recorder")
@@ -104,42 +109,64 @@ def reset(self):
self._i = 0

@property
def activations(self) -> Float[Tensor, "batch n_layers n_patches+1 dim"]:
def activations(self) -> Float[Tensor, "batch n_layers all_patches dim"]:
if self._storage is None:
raise RuntimeError("First call model()")
return self._storage.cpu()


@jaxtyped(typechecker=beartype.beartype)
class ShardedMmapTensor(torch.utils.data.Dataset):
cfg: config.Activations
class Dataset(torch.utils.data.Dataset):
root: str
metadata: "Metadata"

def __init__(self, cfg: config.Activations):
self.cfg = cfg
self._len = cfg.data.n_imgs
self.root = get_acts_dir(cfg)

if not os.path.isdir(self._root):
raise RuntimeError(f"Activations are not saved at '{self._root}'.")
def __init__(self, root: str):
self.root = root
if not os.path.isdir(self.root):
raise RuntimeError(f"Activations are not saved at '{self.root}'.")

self._length = cfg.data.n_imgs
self.metadata = Metadata.load(os.path.join(root, "metadata.json"))

@typing.overload
def __getitem__(self, i: int) -> tuple[Float[Tensor, " d_model"], Int[Tensor, ""]]:
breakpoint()
shard = i // self.cfg.n_per_shard
# for meta in self.metas:
# if meta.lower <= i <= meta.upper:
# acts = np.memmap(
# meta.path, mode="r", dtype=np.float32, shape=meta.flattened
# )
# return torch.from_numpy(acts[i]), torch.tensor(i)

# raise IndexError(f"Index {i} not in range [0, {len(self)}).")
@property
def d_vit(self) -> int:
return self.metadata.d_vit

@jaxtyped(typechecker=beartype.beartype)
def __getitem__(
self, i: Int[np.ndarray, "*batch"] | int
) -> tuple[Float[Tensor, "*batch d_vit"], Int[Tensor, "*batch"]]:
if isinstance(i, int):
shard = i // self.metadata.n_patches_per_shard
pos = i % self.metadata.n_patches_per_shard
acts_fpath = os.path.join(self.root, f"acts{shard:06}.bin")
acts = np.memmap(
acts_fpath,
mode="c",
dtype=np.float32,
shape=(self.metadata.n_patches_per_shard, self.metadata.d_vit),
)
return torch.from_numpy(acts[pos].copy()), torch.tensor(i)
else:
shards = i // self.metadata.n_patches_per_shard
pos = i % self.metadata.n_patches_per_shard
batch = []
for shard, p in zip(shards, pos):
acts_fpath = os.path.join(self.root, f"acts{shard.item():06}.bin")
acts = np.memmap(
acts_fpath,
mode="c",
dtype=np.float32,
shape=(self.metadata.n_patches_per_shard, self.metadata.d_vit),
)
batch.append(torch.from_numpy(acts[p].copy()))
return torch.stack(batch), torch.tensor(i)

def __len__(self) -> int:
return self._length
return (
self.metadata.n_imgs
* self.metadata.n_layers
* (self.metadata.n_patches_per_img + 1)
)
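
Assuming the Dataset above, pulling shuffled activation batches for SAE training might look like this (a sketch; the root path is hypothetical):

import torch
import saev.activations

dataset = saev.activations.Dataset("/local/scratch/saev-acts/abc123")
loader = torch.utils.data.DataLoader(
    dataset, batch_size=4096, shuffle=True, num_workers=4
)
for acts, indices in loader:
    # acts: [4096, d_vit] float32 activations; indices: global row ids.
    ...

Because each __getitem__ memmaps a shard and copies out single rows, random access stays cheap in memory: pages are faulted in on demand rather than loading whole shards.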


@beartype.beartype
@@ -454,6 +481,7 @@ def worker_fn(cfg: config.Activations):
cfg = dataclasses.replace(cfg, device="cpu")

vit = vit.to(cfg.device)
vit = torch.compile(vit)

i = 0
# Calculate and write ViT activations.
@@ -480,32 +508,41 @@ class ShardWriter:
shape: tuple[int, int, int, int]
shard: int
acts_path: str
acts: Float[np.ndarray, "n n_layers n_patches+1 d_vit"] | None
acts: Float[np.ndarray, "n_imgs_per_shard n_layers all_patches d_vit"] | None
filled: int

def __init__(self, cfg: config.Activations):
self.logger = logging.getLogger("shard-writer")

self.root = get_acts_dir(cfg)
os.makedirs(self.root, exist_ok=True)

self.n_per_shard = cfg.n_per_shard // cfg.n_layers // (cfg.n_patches + 1)
self.shape = (self.n_per_shard, cfg.n_layers, cfg.n_patches + 1, cfg.d_vit)
self.n_imgs_per_shard = (
cfg.n_patches_per_shard // cfg.n_layers // (cfg.n_patches_per_img + 1)
)
self.shape = (
self.n_imgs_per_shard,
cfg.n_layers,
cfg.n_patches_per_img + 1,
cfg.d_vit,
)

self.shard = -1
self.acts = None
self.next_shard()

def __setitem__(self, i: slice, val) -> None:
@jaxtyped(typechecker=beartype.beartype)
def __setitem__(
self, i: slice, val: Float[Tensor, "_ n_layers all_patches d_vit"]
) -> None:
assert i.step is None
a, b = i.start, i.stop
assert len(val) == b - a

offset = self.n_per_shard * self.shard
offset = self.n_imgs_per_shard * self.shard

if b >= offset + self.n_per_shard:
if b >= offset + self.n_imgs_per_shard:
# We have run out of space in this mmap'ed file. Let's fill it as much as we can.
n_fit = offset + self.n_per_shard - a
n_fit = offset + self.n_imgs_per_shard - a
self.acts[a - offset : a - offset + n_fit] = val[:n_fit]
self.filled = a - offset + n_fit

@@ -514,10 +551,10 @@ def __setitem__(self, i: slice, val) -> None:
# Recursively call __setitem__ in case we need *another* shard
self[a + n_fit : b] = val[n_fit:]
else:
msg = f"0 <= {a} - {offset} <= {offset} + {self.n_per_shard}"
assert 0 <= a - offset <= offset + self.n_per_shard, msg
msg = f"0 <= {b} - {offset} <= {offset} + {self.n_per_shard}"
assert 0 <= b - offset <= offset + self.n_per_shard, msg
msg = f"0 <= {a} - {offset} <= {offset} + {self.n_imgs_per_shard}"
assert 0 <= a - offset <= offset + self.n_imgs_per_shard, msg
msg = f"0 <= {b} - {offset} <= {offset} + {self.n_imgs_per_shard}"
assert 0 <= b - offset <= offset + self.n_imgs_per_shard, msg
self.acts[a - offset : b - offset] = val
self.filled = b - offset
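
A worked trace of the spill-over case, with invented numbers:

# Suppose n_imgs_per_shard = 100 and we are writing to shard 1 (offset = 100).
# Then `writer[195:210] = batch` (15 images) proceeds as:
#   1. b = 210 >= offset + n_imgs_per_shard = 200, so only n_fit = 200 - 195 = 5 fit.
#   2. acts[95:100] = batch[:5]; the shard is full, so next_shard() opens shard 2.
#   3. The recursive call writer[200:210] = batch[5:] lands entirely in shard 2.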

@@ -541,6 +578,50 @@ def next_shard(self) -> None:
self.logger.info("Opened shard '%s'.", self.acts_path)


@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Metadata:
width: int
height: int
model: str
n_layers: int
n_patches_per_img: int
d_vit: int
seed: int
n_imgs: int
n_patches_per_shard: int
data: str

@classmethod
def from_cfg(cls, cfg: config.Activations) -> "Metadata":
return cls(
cfg.width,
cfg.height,
cfg.model,
cfg.n_layers,
cfg.n_patches_per_img,
cfg.d_vit,
cfg.seed,
cfg.data.n_imgs,
cfg.n_patches_per_shard,
str(cfg.data),
)

@classmethod
def load(cls, fpath) -> "Metadata":
with open(fpath) as fd:
return cls(**json.load(fd))

def dump(self, fpath):
with open(fpath, "w") as fd:
json.dump(dataclasses.asdict(self), fd, indent=4)

@property
def hash(self) -> str:
cfg_str = json.dumps(dataclasses.asdict(self), sort_keys=True)
return hashlib.sha256(cfg_str.encode("utf-8")).hexdigest()
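
Since the hash is computed over the JSON dump with sorted keys, identical configs always map to the same activations directory. A quick sketch of the round trip, with hypothetical values:

meta = Metadata(
    width=224, height=224, model="ViT-B/32", n_layers=3,
    n_patches_per_img=49, d_vit=768, seed=42, n_imgs=1_200_000,
    n_patches_per_shard=2_400_000, data="ImagenetDataset()",
)
meta.dump("metadata.json")
assert Metadata.load("metadata.json") == meta  # frozen dataclass: field-wise equality
print(meta.hash)  # deterministic across runs and machines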


@beartype.beartype
def get_acts_dir(cfg: config.Activations) -> str:
"""
@@ -553,22 +634,11 @@ def get_acts_dir(cfg: config.Activations) -> str:
Returns:
Directory to where activations should be dumped/loaded from.
"""
metadata = {
"width": cfg.width,
"height": cfg.height,
"model": cfg.model,
"data": str(cfg.data),
"n_per_shard": cfg.n_per_shard,
"n_layers": cfg.n_layers,
"seed": cfg.seed,
}

cfg_str = json.dumps(metadata, sort_keys=True)
acts_hash = hashlib.sha256(cfg_str.encode("utf-8")).hexdigest()
acts_dir = os.path.join(helpers.get_cache_dir(), "saev-acts", acts_hash)
metadata = Metadata.from_cfg(cfg)

acts_dir = os.path.join(helpers.get_cache_dir(), "saev-acts", metadata.hash)
os.makedirs(acts_dir, exist_ok=True)

with open(os.path.join(acts_dir, "metadata.json"), "w") as fd:
json.dump(metadata, fd)
metadata.dump(os.path.join(acts_dir, "metadata.json"))

return acts_dir
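
Presumably the writer and reader sides meet through this directory (a sketch, assuming a cfg: config.Activations as above):

import saev.activations as activations

acts_dir = activations.get_acts_dir(cfg)  # content-addressed by Metadata.hash
dataset = activations.Dataset(acts_dir)   # reads metadata.json and the acts*.bin shards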