diff --git a/configs/preprint/classification.toml b/configs/preprint/classification.toml new file mode 100644 index 0000000..85a6db5 --- /dev/null +++ b/configs/preprint/classification.toml @@ -0,0 +1,17 @@ +tag = "classification-v1.0" + +lr = [1e-4, 3e-4, 1e-3, 3e-3] + +n_lr_warmup = 500 +n_sparsity_warmup = 500 + +[sae] +sparsity_coeff = [4e-4, 8e-4, 1.6e-3] +ghost_grads = false +normalize_w_dec = true +remove_parallel_grads = true +exp_factor = [16, 32] + +[data] +scale_mean = true +scale_norm = true diff --git a/saev/webapp.py b/contrib/classification/__init__.py similarity index 100% rename from saev/webapp.py rename to contrib/classification/__init__.py diff --git a/contrib/classification/reproduce.md b/contrib/classification/reproduce.md new file mode 100644 index 0000000..9dec4ae --- /dev/null +++ b/contrib/classification/reproduce.md @@ -0,0 +1,30 @@ +# Reproduce + +You can reproduce our classification control experiments from our preprint by following these instructions. + +The big overview (as described in our paper) is: + +1. Train an SAE on the ImageNet-1K [CLS] token activations from a CLIP ViT-B/16, from the 11th (second-to-last) layer. +2. Show that you get meaningful features, through visualizations. +3. Train a linear probe on the [CLS] token activations from a CLIP ViT-B/16, from the 11th layer, on the Oxford Flowers-102 dataset. +4. Show that we get good accuracy. +5. Manipulate the activations using the proposed SAE features. +6. Be amazed. 
:) + +To do these steps: + +## Record ImageNet-1K activations + +## Train an SAE + +```sh +uv run python -m saev train --sweep configs/preprint/classification.toml --data.shard-root /local/scratch/stevens.994/cache/saev/ac89246f1934b45e2f0487298aebe36ad998b6bd252d880c0c9ec5de78d793c8/ --data.patches cls --data.layer -2 --sae.d-vit 768 +``` + +## Visualize the SAE Features + +## Record Oxford Flowers-102 Activations + +## Train a Linear Probe + +## Manipulate diff --git a/docs/llms.txt b/docs/llms.txt index 65e53f7..9cd6ae8 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -140,18 +140,19 @@ You can run it with `uv run marimo edit saev/webapp.py`. Sub-modules ----------- * saev.activations +* saev.app * saev.config * saev.helpers * saev.imaging +* saev.interactive * saev.nn * saev.test_activations * saev.test_config * saev.test_nn -* saev.test_webapp +* saev.test_training +* saev.test_visuals * saev.training * saev.visuals -* saev.web -* saev.webapp Module saev.activations ======================= @@ -168,7 +169,7 @@ Functions --------- `get_acts_dir(cfg: saev.config.Activations) ‑> str` -: Return the activations filepath based on the relevant values of a config. +: Return the activations directory based on the relevant values of a config. Also saves a metadata.json file to that directory for human reference. Args: @@ -479,6 +480,57 @@ Classes * torch.utils.data.dataset.Dataset * typing.Generic +`MaskedAutoencoder(cfg: saev.config.Activations)` +: Base class for all neural network modules. + + Your models should also subclass this class. + + Modules can also contain other Modules, allowing to nest them in + a tree structure. 
You can assign the submodules as regular attributes:: + + import torch.nn as nn + import torch.nn.functional as F + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + + Submodules assigned in this way will be registered, and will have their + parameters converted too when you call :meth:`to`, etc. + + .. note:: + As per the example above, an ``__init__()`` call to the parent class + must be made before assignment on the child. + + :ivar training: Boolean represents whether this module is in training or + evaluation mode. + :vartype training: bool + + Initialize internal Module state, shared by both nn.Module and ScriptModule. + + ### Ancestors (in MRO) + + * torch.nn.modules.module.Module + + ### Methods + + `forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> Callable[..., Any]` + : Define the computation performed at every call. + + Should be overridden by all subclasses. + + .. note:: + Although the recipe for forward pass needs to be defined within + this function, one should call the :class:`Module` instance afterwards + instead of this since the former takes care of running the + registered hooks while the latter silently ignores them. + `Metadata(model_family: str, model_ckpt: str, layers: tuple[int, ...], n_patches_per_img: int, cls_token: bool, d_vit: int, seed: int, n_imgs: int, n_patches_per_shard: int, data: str)` : Metadata(model_family: str, model_ckpt: str, layers: tuple[int, ...], n_patches_per_img: int, cls_token: bool, d_vit: int, seed: int, n_imgs: int, n_patches_per_shard: int, data: str) @@ -673,6 +725,9 @@ Classes `reset(self)` : +Module saev.app +=============== + Module saev.config ================== All configs for all saev jobs. @@ -992,6 +1047,12 @@ Useful helpers for `saev`. 
Functions --------- +`flattened(dct: dict[str, object], *, sep: str = '.') ‑> dict[str, str | int | float | bool | None]` +: Flatten a potentially nested dict to a single-level dict with `.`-separated keys. + +`get(dct: dict[str, object], key: str, *, sep: str = '.') ‑> object` +: + `get_cache_dir() ‑> str` : Get cache directory from environment variables, defaulting to the current working directory (.) @@ -1019,6 +1080,20 @@ Functions `add_highlights(img: PIL.Image.Image, patches: jaxtyping.Float[ndarray, 'n_patches'], *, upper: float | None = None) ‑> PIL.Image.Image` : +Namespace saev.interactive +========================== + +Sub-modules +----------- +* saev.interactive.features +* saev.interactive.metrics + +Module saev.interactive.features +================================ + +Module saev.interactive.metrics +=============================== + Module saev.nn ============== Neural network architectures for sparse autoencoders. @@ -1164,8 +1239,26 @@ Functions `test_safe_mse_zero_x_hat()` : -Module saev.test_webapp -======================= +Module saev.test_training +========================= + +Functions +--------- + +`test_split_cfgs_no_bad_keys()` +: + +`test_split_cfgs_on_multiple_keys_with_multiple_per_key()` +: + +`test_split_cfgs_on_single_key()` +: + +`test_split_cfgs_on_single_key_with_multiple_per_key()` +: + +Module saev.test_visuals +======================== Functions --------- @@ -1180,9 +1273,6 @@ Trains many SAEs in parallel to amortize the cost of loading a single batch of d Functions --------- -`check_cfgs(cfgs: list[saev.config.Train])` -: - `evaluate(cfgs: list[saev.config.Train], saes: torch.nn.modules.container.ModuleList) ‑> list[saev.training.EvalMetrics]` : Evaluates SAE quality by counting the number of dead features and the number of dense features. Also makes histogram plots to help human qualitative comparison. 
@@ -1195,9 +1285,21 @@ Functions `main(cfgs: list[saev.config.Train]) ‑> list[str]` : +`make_hashable(obj)` +: + `make_saes(cfgs: list[saev.config.SparseAutoencoder]) ‑> tuple[torch.nn.modules.container.ModuleList, list[dict[str, object]]]` : +`split_cfgs(cfgs: list[saev.config.Train]) ‑> list[list[saev.config.Train]]` +: Splits configs into groups that can be parallelized. + + Arguments: + A list of configs from a sweep file. + + Returns: + A list of lists, where the configs in each sublist do not differ in any keys that are in `CANNOT_PARALLELIZE`. This means that each sublist is a valid "parallel" set of configs for `train`. + `train(cfgs: list[saev.config.Train]) ‑> tuple[torch.nn.modules.container.ModuleList, saev.training.ParallelWandbRun, int]` : Explicitly declare the optimizer, schedulers, dataloader, etc outside of `main` so that all the variables are dropped from scope and can be garbage collected. @@ -1379,25 +1481,92 @@ Classes `patches: jaxtyping.Float[Tensor, 'n_patches']` : -Namespace saev.web -================== +Namespace contrib +================= Sub-modules ----------- -* saev.web.probing +* contrib.classification +* contrib.mae +* contrib.semseg -Module saev.web.probing -======================= +Module contrib.classification +============================= -Module saev.webapp -================== +Sub-modules +----------- +* contrib.classification.app +* contrib.classification.config +* contrib.classification.probing -Namespace contrib -================= +Module contrib.classification.app +================================= + +Module contrib.classification.config +==================================== + +Module contrib.classification.probing +===================================== + +Namespace contrib.mae +===================== Sub-modules ----------- -* contrib.semseg +* contrib.mae.modeling + +Module contrib.mae.modeling +=========================== + +Functions +--------- + +`download()` +: + +`load_model(ckpt: str) ‑> 
contrib.mae.modeling.MaskedAutoencoder` +: + +Classes +------- + +`MaskedAutoencoder()` +: Base class for all neural network modules. + + Your models should also subclass this class. + + Modules can also contain other Modules, allowing to nest them in + a tree structure. You can assign the submodules as regular attributes:: + + import torch.nn as nn + import torch.nn.functional as F + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = nn.Conv2d(1, 20, 5) + self.conv2 = nn.Conv2d(20, 20, 5) + + def forward(self, x): + x = F.relu(self.conv1(x)) + return F.relu(self.conv2(x)) + + Submodules assigned in this way will be registered, and will have their + parameters converted too when you call :meth:`to`, etc. + + .. note:: + As per the example above, an ``__init__()`` call to the parent class + must be made before assignment on the child. + + :ivar training: Boolean represents whether this module is in training or + evaluation mode. + :vartype training: bool + + Initialize internal Module state, shared by both nn.Module and ScriptModule. + + ### Ancestors (in MRO) + + * torch.nn.modules.module.Module Module contrib.semseg ===================== @@ -1407,6 +1576,7 @@ This sub-module reproduces the results from Section 4.2 of our paper. As an overview: +0. Record ViT activations for ADE20K. 1. Train an SAE on activations. 2. Train a linear probe on semantic segmentation task using ADE20K. 3. Establish baseline metrics for the linear probe. @@ -1416,20 +1586,75 @@ As an overview: Details can be found below. +# Record ViT Activations for SAE and Linear Probe Training + # Train an SAE on ViT Activations # Train a Linear Probe on Semantic Segmentation +Now train a linear probe on the activations. 
+ +```sh +uv run python -m contrib.semseg train \ + --train-acts.shard-root /local/scratch/stevens.994/cache/saev/a860104bf29d6093dd18b8e2dccd2e7efdfcd9fac35dceb932795af05187cb9f/ \ + --train-acts.no-scale-mean \ + --train-acts.no-scale-norm \ + --val-acts.shard-root /local/scratch/stevens.994/cache/saev/c6756186d1490ac69fab6f8efb883a1c59d44d0594d99397051bfe8e409ca91d/ \ + --val-acts.no-scale-mean \ + --val-acts.no-scale-norm \ + --imgs.root /research/nfs_su_809/workspace/stevens.994/datasets/ade20k/ \ + --sweep contrib/semseg/sweep.toml +``` + # Establish Linear Probe Baseline Metrics # Identify Class-Specific Feature Vectors in the SAE +```sh +uv run python -m contrib.semseg visuals \ + --sae-ckpt checkpoints/ercgckr1/sae.pt \ + --acts.shard-root /local/scratch/stevens.994/cache/saev/a860104bf29d6093dd18b8e2dccd2e7efdfcd9fac35dceb932795af05187cb9f/ \ + --acts.no-scale-mean \ + --acts.no-scale-norm \ + --imgs.root /research/nfs_su_809/workspace/stevens.994/datasets/ade20k/ +``` + +This will tell you to run a particular command in order to generate visuals for use with `saev.interactive.features`. + +# Manipulate ViT Activations + +```sh +uv run python -m contrib.semseg manipulate \ + --acts.shard-root /local/scratch/stevens.994/cache/saev/c6756186d1490ac69fab6f8efb883a1c59d44d0594d99397051bfe8e409ca91d/ \ + --acts.no-scale-mean \ + --acts.no-scale-norm \ + --imgs.root /research/nfs_su_809/workspace/stevens.994/datasets/ade20k/ \ + --sae-ckpt checkpoints/ercgckr1/sae.pt \ + --probe-ckpt checkpoints/semseg/lr_0_003__wd_0_1/model_ep199_step4000.pt \ + --ade20k-classes 29 \ + --sae-latents 8541 5818 10230 +``` +# Dimension Key + +Throughout the code, variables are annotated with shape suffixes, as [recommended by Noam Shazeer](https://medium.com/@NoamShazeer/shape-suffixes-good-coding-style-f836e72e24fd). 
+ +The key for these suffixes: + +* B: batch size +* W: width in patches (typically 14 or 16) +* H: height in patches (typically 14 or 16) +* D: ViT activation dimension (typically 768 or 1024) +* S: SAE latent dimension (768 x 16, etc) +* L: Number of latents being manipulated at once (typically 1-5 at a time) +* C: Number of classes in ADE20K (151) + Sub-modules ----------- * contrib.semseg.config * contrib.semseg.dashboard * contrib.semseg.dashboard2 * contrib.semseg.interactive +* contrib.semseg.manipulation * contrib.semseg.training * contrib.semseg.validation * contrib.semseg.visuals @@ -1446,8 +1671,40 @@ Functions Classes ------- -`Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 200, batch_size: int = 1024, n_workers: int = 32, train_acts: saev.config.DataLoad = , val_acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , eval_every: int = 50, device: str = 'cuda', ckpt_path: str = './checkpoints/semseg', seed: int = 42, log_to: str = './logs')` -: Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 200, batch_size: int = 1024, n_workers: int = 32, train_acts: saev.config.DataLoad = , val_acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , eval_every: int = 50, device: str = 'cuda', ckpt_path: str = './checkpoints/semseg', seed: int = 42, log_to: str = './logs') +`Manipulation(probe_ckpt: str = './checkpoints/semseg/lr_0_001__wd_0_1/model.pt', sae_ckpt: str = './checkpoints/abcdef/sae.pt', ade20k_classes: list[int] = , sae_latents: list[int] = , acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, device: str = 'cuda')` +: Manipulation(probe_ckpt: str = './checkpoints/semseg/lr_0_001__wd_0_1/model.pt', sae_ckpt: str = './checkpoints/abcdef/sae.pt', ade20k_classes: list[int] = , sae_latents: list[int] = , acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 
32, device: str = 'cuda') + + ### Class variables + + `acts: saev.config.DataLoad` + : Configuration for the saved ADE20K validation ViT activations. + + `ade20k_classes: list[int]` + : One or more ADE20K classes to track. + + `batch_size: int` + : Batch size for both linear probe and SAE. + + `device: str` + : Hardware for linear probe and SAE inference. + + `imgs: saev.config.Ade20kDataset` + : Configuration for the ADE20K validation dataset. + + `n_workers: int` + : Number of dataloader workers. + + `probe_ckpt: str` + : Linear probe checkpoint. + + `sae_ckpt: str` + : SAE checkpoint. + + `sae_latents: list[int]` + : one or more SAE latents to manipulate. + +`Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 400, batch_size: int = 1024, n_workers: int = 32, train_acts: saev.config.DataLoad = , val_acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/semseg', seed: int = 42, log_to: str = './logs')` +: Train(learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 400, batch_size: int = 1024, n_workers: int = 32, train_acts: saev.config.DataLoad = , val_acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , eval_every: int = 100, device: str = 'cuda', ckpt_path: str = './checkpoints/semseg', seed: int = 42, log_to: str = './logs') ### Class variables @@ -1564,6 +1821,31 @@ Sub-modules Module contrib.semseg.interactive.feature_analysis ================================================== +Module contrib.semseg.manipulation +================================== +Manipulate representations by increasing or decreasing the presence of a feature in a ViT activation, then use the linear probe for inference. + +Record class-specific scores before and after manipulation to see that you can directly manipulate abilities to complete downstream tasks. 
+ +## Dimension Key + +* B: batch size +* W: width in patches (typically 14 or 16) +* H: height in patches (typically 14 or 16) +* D: ViT activation dimension (typically 768 or 1024) +* S: SAE latent dimension (768 x 16, etc) +* L: Number of latents being manipulated at once (typically 1-5 at a time) +* C: Number of classes in ADE20K (151) + +Functions +--------- + +`main(cfg: contrib.semseg.config.Manipulation)` +: + +`manipulate(cfg: contrib.semseg.config.Manipulation, sae: saev.nn.SparseAutoencoder, acts_BWHD: jaxtyping.Float[Tensor, 'batch width height d_vit']) ‑> tuple[jaxtyping.Float[Tensor, 'batch width height d_vit'], jaxtyping.Float[Tensor, 'batch width height d_vit']]` +: + Module contrib.semseg.training ============================== diff --git a/docs/saev/activations.html b/docs/saev/activations.html index 72018ec..e8bf5c4 100644 --- a/docs/saev/activations.html +++ b/docs/saev/activations.html @@ -46,7 +46,7 @@

Functions

def get_acts_dir(cfg: Activations) ‑> str
-

Return the activations filepath based on the relevant values of a config. +

Return the activations directory based on the relevant values of a config. Also saves a metadata.json file to that directory for human reference.

Args

@@ -930,6 +930,83 @@

Ancestors

  • typing.Generic
  • +
    +class MaskedAutoencoder +(cfg: Activations) +
    +
    +

    Base class for all neural network modules.

    +

    Your models should also subclass this class.

    +

    Modules can also contain other Modules, allowing to nest them in +a tree structure. You can assign the submodules as regular attributes::

    +
    import torch.nn as nn
    +import torch.nn.functional as F
    +
    +class Model(nn.Module):
    +    def __init__(self) -> None:
    +        super().__init__()
    +        self.conv1 = nn.Conv2d(1, 20, 5)
    +        self.conv2 = nn.Conv2d(20, 20, 5)
    +
    +    def forward(self, x):
    +        x = F.relu(self.conv1(x))
    +        return F.relu(self.conv2(x))
    +
    +

    Submodules assigned in this way will be registered, and will have their +parameters converted too when you call :meth:to, etc.

    +
    +

    Note

    +

    As per the example above, an __init__() call to the parent class +must be made before assignment on the child.

    +
    +

    :ivar training: Boolean represents whether this module is in training or +evaluation mode. +:vartype training: bool

    +

    Initialize internal Module state, shared by both nn.Module and ScriptModule.

    +
    + +Expand source code + +
    @jaxtyped(typechecker=beartype.beartype)
    +class MaskedAutoencoder(torch.nn.Module):
    +    def __init__(self, cfg: config.Activations):
    +        super().__init__()
    +        assert cfg.model_family == "mae"
    +
    +        import mae
    +
    +        self.model = mae.load_model(cfg.model_ckpt)
    +        self.recorder = VitRecorder(cfg).register(self.model.vit.encoder.layer)
    +
    +    def forward(self, batch: Float[Tensor, "batch 3 width height"]):
    +        self.recorder.reset()
    +
    +        y = self.model(batch)
    +
    +        return y, self.recorder.activations
    +
    +

    Ancestors

    +
      +
    • torch.nn.modules.module.Module
    • +
    +

    Methods

    +
    +
    +def forward(self, batch: jaxtyping.Float[Tensor, 'batch 3 width height']) ‑> Callable[..., Any] +
    +
    +

    Define the computation performed at every call.

    +

    Should be overridden by all subclasses.

    +
    +

    Note

    +

    Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

    +
    +
    +
    +
    class Metadata (model_family: str, model_ckpt: str, layers: tuple[int, ...], n_patches_per_img: int, cls_token: bool, d_vit: int, seed: int, n_imgs: int, n_patches_per_shard: int, data: str) @@ -1514,6 +1591,12 @@

    Imagenet

  • +

    MaskedAutoencoder

    + +
  • +
  • Metadata

    • cls_token
    • diff --git a/docs/saev/helpers.html b/docs/saev/helpers.html index 07a331d..ed4e6b2 100644 --- a/docs/saev/helpers.html +++ b/docs/saev/helpers.html @@ -36,6 +36,18 @@

      Module saev.helpers

      Functions

      +
      +def flattened(dct: dict[str, object], *, sep: str = '.') ‑> dict[str, str | int | float | bool | None] +
      +
      +

      Flatten a potentially nested dict to a single-level dict with .-separated keys.

      +
      +
      +def get(dct: dict[str, object], key: str, *, sep: str = '.') ‑> object +
      +
      +
      +
      def get_cache_dir() ‑> str
      @@ -139,6 +151,8 @@

      Args

    • Functions

    • diff --git a/docs/saev/index.html b/docs/saev/index.html index 6de1227..cdb907f 100644 --- a/docs/saev/index.html +++ b/docs/saev/index.html @@ -124,7 +124,7 @@

      Visualize the Learned Features

      This will record the top 128 patches, and then save the unique images among those top 128 patches for each feature in the trained SAE. It will cache these best activations to disk, then start saving images to visualize later on.

      -

      saev.webapp is a small web application based on marimo to interactively look at these images.

      +

      saev.webapp is a small web application based on marimo to interactively look at these images.

      You can run it with uv run marimo edit saev/webapp.py.

      Sweeps

      @@ -144,6 +144,10 @@

      Sub-modules

      To save lots of activations, we want to do things in parallel, with lots of slurm jobs, and save multiple files, rather than just one …

      +
      saev.app
      +
      +
      +
      saev.config

      All configs for all saev jobs …

      @@ -156,6 +160,10 @@

      Sub-modules

      +
      saev.interactive
      +
      +
      +
      saev.nn

      Neural network architectures for sparse autoencoders.

      @@ -173,7 +181,11 @@

      Sub-modules

      Uses hypothesis and hypothesis-torch to generate test cases to compare our …

      -
      saev.test_webapp
      +
      saev.test_training
      +
      +
      +
      +
      saev.test_visuals
      @@ -185,14 +197,6 @@

      Sub-modules

      There is some important notation used only in this file to dramatically shorten variable names …

      -
      saev.web
      -
      -
      -
      -
      saev.webapp
      -
      -
      -
      @@ -219,18 +223,19 @@

      Sub-modules

    • Sub-modules

    diff --git a/docs/saev/interactive/features.html b/docs/saev/interactive/features.html new file mode 100644 index 0000000..f518759 --- /dev/null +++ b/docs/saev/interactive/features.html @@ -0,0 +1,57 @@ + + + + + + +saev.interactive.features API documentation + + + + + + + + + + + + + +
    +
    +
    +

    Module saev.interactive.features

    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    +
    + +
    + + + diff --git a/docs/saev/interactive/index.html b/docs/saev/interactive/index.html new file mode 100644 index 0000000..117e2b8 --- /dev/null +++ b/docs/saev/interactive/index.html @@ -0,0 +1,74 @@ + + + + + + +saev.interactive API documentation + + + + + + + + + + + + + +
    + + +
    + + + diff --git a/docs/saev/webapp.html b/docs/saev/interactive/metrics.html similarity index 96% rename from docs/saev/webapp.html rename to docs/saev/interactive/metrics.html index 8e68ae0..481c032 100644 --- a/docs/saev/webapp.html +++ b/docs/saev/interactive/metrics.html @@ -4,7 +4,7 @@ -saev.webapp API documentation +saev.interactive.metrics API documentation @@ -24,7 +24,7 @@
    -

    Module saev.webapp

    +

    Module saev.interactive.metrics

    @@ -44,7 +44,7 @@

    Module saev.webapp

    diff --git a/docs/saev/test_training.html b/docs/saev/test_training.html new file mode 100644 index 0000000..930eb1c --- /dev/null +++ b/docs/saev/test_training.html @@ -0,0 +1,92 @@ + + + + + + +saev.test_training API documentation + + + + + + + + + + + + + +
    +
    +
    +

    Module saev.test_training

    +
    +
    +
    +
    +
    +
    +
    +
    +

    Functions

    +
    +
    +def test_split_cfgs_no_bad_keys() +
    +
    +
    +
    +
    +def test_split_cfgs_on_multiple_keys_with_multiple_per_key() +
    +
    +
    +
    +
    +def test_split_cfgs_on_single_key() +
    +
    +
    +
    +
    +def test_split_cfgs_on_single_key_with_multiple_per_key() +
    +
    +
    +
    +
    +
    +
    +
    +
    + +
    + + + diff --git a/docs/saev/test_webapp.html b/docs/saev/test_visuals.html similarity index 95% rename from docs/saev/test_webapp.html rename to docs/saev/test_visuals.html index adf02e0..a6d575c 100644 --- a/docs/saev/test_webapp.html +++ b/docs/saev/test_visuals.html @@ -4,7 +4,7 @@ -saev.test_webapp API documentation +saev.test_visuals API documentation @@ -24,7 +24,7 @@
    -

    Module saev.test_webapp

    +

    Module saev.test_visuals

    @@ -35,7 +35,7 @@

    Module saev.test_webapp

    Functions

    -
    +
    def test_gather_batched_small()
    @@ -58,7 +58,7 @@

    Functions

  • Functions

  • diff --git a/docs/saev/training.html b/docs/saev/training.html index c65fca9..343f165 100644 --- a/docs/saev/training.html +++ b/docs/saev/training.html @@ -36,12 +36,6 @@

    Module saev.training

    Functions

    -
    -def check_cfgs(cfgs: list[Train]) -
    -
    -
    -
    def evaluate(cfgs: list[Train], saes: torch.nn.modules.container.ModuleList) ‑> list[EvalMetrics]
    @@ -65,12 +59,28 @@

    Functions

    +
    +def make_hashable(obj) +
    +
    +
    +
    def make_saes(cfgs: list[SparseAutoencoder]) ‑> tuple[torch.nn.modules.container.ModuleList, list[dict[str, object]]]
    +
    +def split_cfgs(cfgs: list[Train]) ‑> list[list[Train]] +
    +
    +

    Splits configs into groups that can be parallelized.

    +

    Arguments

    +

    A list of configs from a sweep file.

    +

    Returns

    +

    A list of lists, where the configs in each sublist do not differ in any keys that are in CANNOT_PARALLELIZE. This means that each sublist is a valid "parallel" set of configs for train().

    +
    def train(cfgs: list[Train]) ‑> tuple[torch.nn.modules.container.ModuleList, ParallelWandbRun, int]
    @@ -384,11 +394,12 @@

    Methods

  • Functions

  • diff --git a/justfile b/justfile index 3894585..1f65c5f 100644 --- a/justfile +++ b/justfile @@ -4,7 +4,7 @@ docs: lint uv run python scripts/docs.py --pkg-names saev contrib --fpath docs/llms.txt test: lint - uv run pytest --cov saev probing -n auto saev + uv run pytest --cov saev -n auto saev lint: fmt fd -e py | xargs ruff check @@ -12,6 +12,10 @@ lint: fmt fmt: fd -e py | xargs isort fd -e py | xargs ruff format --preview + fd -e elm | xargs elm-format --yes clean: uv run python -c 'import datasets; print(datasets.load_dataset("ILSVRC/imagenet-1k").cleanup_cache_files())' + +build: fmt + cd web && elm make apps/explore/Main.elm --output apps/explore/dist/app.js diff --git a/logbook.md b/logbook.md index f2e486d..83d390a 100644 --- a/logbook.md +++ b/logbook.md @@ -846,3 +846,79 @@ They suggest that measuring sensitivity (how reliably a feature activates for te I want to see which hparam works the best. I need to see how the predictions are to see if the linear model is any good. + +# 12/04/2024 + +I can demonstrate a lot of manipulation in various ways: + +* Linear probe + semantic segmentation +* Masked autoencoder pre-trained ViT with pixel reconstruction +* VLM with language generations +* BioCLIP with probabilities + +I want to build Gradio demos for all of these. +I also want to cherry pick qualitative examples. +And finally, I want to present a set of hparams for training these things at the scale I'm training. + +> We believe the only way to really understand the precise details of a technology is to use it: to see its strengths and weaknesses for yourself, and envision where it could go in the future. + +From https://goodfire.ai/blog/research-preview/ + +Framework for intervention: + +1. Train an SAE on a particular set of vision transformer activations. +2. Train a (or use an existing pre-trained) task-specific head: make predictions based on [CLS] activations, use an LLM to generate captions based on an image and a prompt, etc. +3. 
Manipulate the vision transformer activations using the SAE-proposed features. +4. Compare task-specific outputs before and after the intervention. + + +## Notes from Meeting + +Talking about faithfulness, causal intervention, etc. +But this method is not supposed to compete with INTR. +It should be a good visual trait extractor. +It should lead to highly factored visual traits. + +We want to show the quality of the visual traits. +How can we demonstrate that? +By manipulating traits for particular classification traits. +This is just a means to an end. + +If two species are inseparable by human eye, does the model find traits that consistently fire/don't fire between the two species to answer that question? + +Can we find a difference between reticulated giraffe species? +What about a lack of differences between two species of red wolves that were recently merged into a single species? + +Experiments + +two "understanding" + +dino vs clip (pre-training modality) +bioclip vs clip (domain/finetuning effects) + +four "control" + +image classification -> CLIP +sem seg -> DINO +mae (image gen) -> MAE +image captioning (vqa? moondream) -> Moondream + +don't compare to protopnet or stuff in intro, it can come up in related work + +So what's stopping us from releasing a preprint? + +1. Experimental results +2. Dashboards +3. Writing + +So I will: + +1. Kick off jobs training on normalized activations for DINOv2 (patch) and CLIP (CLS) +2. Train an SAE on unnormalized activations for CLIP CLS to get something ready. +3. Build gradio demo for SAE inference on CLIP activations. + +# 12/05/2024 + +Instead of gradio or streamlit, I'll just use gradio as an inference endpoint and then Elm -> static html + js as the polished web demos. +I probably should write one pure inference demo in gradio to demonstrate how simple it can be, but for polished, interactive experiences, I want to write the frontends myself. 
+But this path has many pitfalls---do not get caught up in frontends in favor of writing a paper. diff --git a/main.py b/main.py deleted file mode 100644 index d7b2f1c..0000000 --- a/main.py +++ /dev/null @@ -1,154 +0,0 @@ -import logging -import tomllib -import typing - -import tyro - -import saev - -log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s" -logging.basicConfig(level=logging.INFO, format=log_format) - -logger = logging.getLogger("main") - - -def activations(cfg: typing.Annotated[saev.ActivationsConfig, tyro.conf.arg(name="")]): - """ - Save ViT activations for use later on. - - Args: - cfg: Configuration for activations. - """ - import saev.activations - - saev.activations.dump(cfg) - - -def sweep(cfg: typing.Annotated[saev.TrainConfig, tyro.conf.arg(name="")], sweep: str): - """ - Run a grid search over a set of hyperparameters. - - Args: - cfg: Baseline config for training an SAE. - sweep: Path to .toml file defining the sweep parameters. - """ - import submitit - - import saev.config - import saev.training - - with open(sweep, "rb") as fd: - cfgs, errs = saev.config.grid(cfg, tomllib.load(fd)) - - if errs: - for err in errs: - logger.warning("Error in config: %s", err) - return - - logger.info("Sweep has %d experiments.", len(cfgs)) - - if cfg.slurm: - executor = submitit.SlurmExecutor(folder=cfg.log_to) - executor.update_parameters( - time=60, - partition="preemptible", - gpus_per_node=1, - cpus_per_task=cfg.n_workers + 4, - stderr_to_stdout=True, - account=cfg.slurm_acct, - ) - else: - executor = submitit.DebugExecutor(folder=cfg.log_to) - - job = executor.submit(saev.training.main, cfgs) - job.result() - - # for i, result in enumerate(submitit.helpers.as_completed(jobs)): - # exp_id = result.result() - # logger.info("Finished task %s (%d/%d)", exp_id, i + 1, len(jobs)) - - -def train(cfg: typing.Annotated[saev.TrainConfig, tyro.conf.arg(name="")]): - def fn(): - import saev.training - - saev.training.main(cfg) - - import submitit - 
- if cfg.slurm: - executor = submitit.SlurmExecutor(folder=cfg.log_to) - executor.update_parameters( - time=30, - partition="debug", - gpus_per_node=1, - cpus_per_task=12, - stderr_to_stdout=True, - account=cfg.slurm_acct, - ) - else: - executor = submitit.DebugExecutor(folder=cfg.log_to) - - job = executor.submit(fn) - job.result() - - -def evaluate(cfg: typing.Annotated[saev.EvaluateConfig, tyro.conf.arg(name="")]): - def run_histograms(): - import saev.training - - return saev.training.evaluate(cfg.histograms) - - def run_broden(): - import saev.broden - - return saev.broden.evaluate(cfg.broden) - - def run_imagenet1k(): - import saev.imagenet1k - - return saev.imagenet1k.evaluate(cfg.imagenet) - - import submitit - - if cfg.slurm: - executor = submitit.SlurmExecutor(folder=cfg.log_to) - executor.update_parameters( - time=30, - partition="debug", - gpus_per_node=1, - cpus_per_task=12, - stderr_to_stdout=True, - account=cfg.slurm_acct, - ) - else: - executor = submitit.DebugExecutor(folder=cfg.log_to) - - jobs = [] - # jobs.append(executor.submit(run_histograms)) - # jobs.append(executor.submit(run_broden)) - jobs.append(executor.submit(run_imagenet1k)) - for job in jobs: - job.result() - - -# def webapp(cfg: typing.Annotated[saev.WebappConfig, tyro.conf.arg(name="")]): -# import saev.webapp - -# saev.webapp.main(cfg) - -# print() -# print("To view the webapp, run:") -# print() -# print(" uv run marimo edit webapp.py") -# print() - - -if __name__ == "__main__": - tyro.extras.subcommand_cli_from_dict({ - "activations": activations, - "sweep": sweep, - "evaluate": evaluate, - # "webapp": webapp, - }) - logger.info("Done.") diff --git a/saev/__main__.py b/saev/__main__.py index a519730..4a00221 100644 --- a/saev/__main__.py +++ b/saev/__main__.py @@ -40,12 +40,11 @@ def train( """ import submitit - import saev.config - import saev.training + from . 
@beartype.beartype
def flattened(
    dct: dict[str, object], *, sep: str = "."
) -> dict[str, str | int | float | bool | None]:
    """
    Flatten a potentially nested dict to a single-level dict with `sep`-separated keys.

    Nested dicts that are empty contribute no keys to the result.

    Args:
        dct: Possibly nested dict with string keys.
        sep: String placed between a parent key and each of its nested keys.

    Returns:
        A new single-level dict; `dct` itself is not modified.
    """
    new = {}
    for key, value in dct.items():
        if not isinstance(value, dict):
            new[key] = value
            continue

        # Forward `sep` so every nesting level is joined the same way.
        for nested_key, nested_value in flattened(value, sep=sep).items():
            new[key + sep + nested_key] = nested_value

    return new


@beartype.beartype
def get(dct: dict[str, object], key: str, *, sep: str = ".") -> object:
    """
    Look up a value in a nested dict using a `sep`-separated key path.

    Args:
        dct: Nested dict to read from.
        key: Path such as `sae.exp_factor`; each part indexes one level deeper.
        sep: Separator used to split `key` into parts.

    Returns:
        Whatever is stored at the end of the path.

    Raises:
        KeyError: If a part of the path is missing from the dict at that level.
    """
    node = dct
    for part in key.split(sep):
        node = node[part]

    return node
@app.cell
def __(alt, df, mo):
    # `make_df` returns None when no run produced a usable row. Halt here with
    # a message instead of failing on `None.select`; cells that depend on
    # `chart` are then not run either.
    mo.stop(df is None, "No runs found for this tag.")

    # Scatter of L0 against MSE; learning rate is encoded as colour and the
    # sparsity coefficient as marker shape. Wrapping the chart in
    # `mo.ui.altair_chart` exposes the current selection as `chart.value`.
    chart = mo.ui.altair_chart(
        alt.Chart(
            df.select(
                "summary/eval/l0",
                "summary/losses/mse",
                "id",
                "config/sae/sparsity_coeff",
                "config/lr",
            )
        )
        .mark_point()
        .encode(
            x=alt.X("summary/eval/l0"),
            y=alt.Y("summary/losses/mse"),
            tooltip=["id", "config/lr"],
            color="config/lr:Q",
            shape="config/sae/sparsity_coeff:N",
        )
    )
    chart
    return (chart,)
Dist.") + + values_hist_ax.hist(np.log10(values.astype(float)), bins=bins) + values_hist_ax.set_title(f"{id} Mean Val. Distribution") + + scatter_fig.tight_layout() + hist_fig.tight_layout() + return ( + bins, + freq_hist_ax, + freqs, + hist_axes, + hist_fig, + id, + scatter_ax, + scatter_axes, + scatter_fig, + sub_df, + values, + values_hist_ax, + ) + + +@app.cell +def __(scatter_fig): + scatter_fig + return + + +@app.cell +def __(hist_fig): + hist_fig + return + + +@app.cell +def __(chart, df, pl): + df.join(chart.value.select("id"), on="id", how="inner").sort( + by="summary/eval/l0" + ).select("id", pl.selectors.starts_with("config/")) + return + + +@app.cell +def __(Float, beartype, jaxtyped, np): + @jaxtyped(typechecker=beartype.beartype) + def plot_dist( + freqs: Float[np.ndarray, " d_sae"], + freqs_log_range: tuple[float, float], + values: Float[np.ndarray, " d_sae"], + values_log_range: tuple[float, float], + ax, + ): + log_sparsity = np.log10(freqs + 1e-9) + log_values = np.log10(values + 1e-9) + + mask = np.ones(len(log_sparsity)).astype(bool) + min_log_freq, max_log_freq = freqs_log_range + mask[log_sparsity < min_log_freq] = False + mask[log_sparsity > max_log_freq] = False + min_log_value, max_log_value = values_log_range + mask[log_values < min_log_value] = False + mask[log_values > max_log_value] = False + + n_shown = mask.sum() + ax.scatter( + log_sparsity[mask], + log_values[mask], + marker=".", + alpha=0.1, + color="tab:blue", + label=f"Shown ({n_shown})", + ) + n_filtered = (~mask).sum() + ax.scatter( + log_sparsity[~mask], + log_values[~mask], + marker=".", + alpha=0.1, + color="tab:red", + label=f"Filtered ({n_filtered})", + ) + + ax.axvline(min_log_freq, linewidth=0.5, color="tab:red") + ax.axvline(max_log_freq, linewidth=0.5, color="tab:red") + ax.axhline(min_log_value, linewidth=0.5, color="tab:red") + ax.axhline(max_log_value, linewidth=0.5, color="tab:red") + + ax.set_xlabel("Feature Frequency (log10)") + # ax.set_ylabel("Mean Activation 
@app.cell
def __(
    beartype,
    get_data_key,
    get_model_key,
    json,
    load_freqs,
    load_mean_values,
    mo,
    os,
    pl,
    tag_input,
    wandb,
):
    @beartype.beartype
    def make_df(tag: str):
        """
        Build one dataframe row per WandB run whose `config.tag` equals `tag`.

        Each row holds the run id, the run summary under `summary/...`, the run
        config under `config/...`, and `model_key`/`data_key` derived from the
        `metadata.json` found in the run's shard root. Runs without logged
        eval tables, or with an unrecognised dataset, are skipped with a
        printed message.

        Returns:
            A polars DataFrame with an added `config/sae/d_sae` column, or
            None when no run produced a row.
        """
        runs = wandb.Api().runs(path="samuelstevens/saev", filters={"config.tag": tag})

        rows = []
        for run in mo.status.progress_bar(
            runs,
            remove_on_exit=True,
            title="Loading",
            subtitle="Parsing runs from WandB",
        ):
            row = {}
            row["id"] = run.id

            row.update({f"summary/{key}": value for key, value in run.summary.items()})
            try:
                row["summary/eval/freqs"] = load_freqs(run)
            except ValueError:
                print(f"Run {run.id} did not log eval/freqs.")
                continue
            except RuntimeError:
                print(f"Wandb blew up on run {run.id}.")
                continue
            try:
                row["summary/eval/mean_values"] = load_mean_values(run)
            except ValueError:
                print(f"Run {run.id} did not log eval/mean_values.")
                continue
            except RuntimeError:
                print(f"Wandb blew up on run {run.id}.")
                continue

            # Pop the nested sections first so that the generic `config/`
            # update below only sees the remaining top-level keys.
            row.update({
                f"config/data/{key}": value
                for key, value in run.config.pop("data").items()
            })
            row.update({
                f"config/sae/{key}": value
                for key, value in run.config.pop("sae").items()
            })

            row.update({f"config/{key}": value for key, value in run.config.items()})

            with open(
                os.path.join(row["config/data/shard_root"], "metadata.json")
            ) as fd:
                metadata = json.load(fd)

            # NOTE: an unknown model leaves `model_key` as None but keeps the
            # run, whereas an unknown dataset drops the run.
            row["model_key"] = get_model_key(metadata)

            data_key = get_data_key(metadata)
            if data_key is None:
                print(f"Bad run: {run.id}")
                continue
            row["data_key"] = data_key

            row["config/d_vit"] = metadata["d_vit"]
            rows.append(row)

        if not rows:
            return None

        df = pl.DataFrame(rows).with_columns(
            (pl.col("config/sae/d_vit") * pl.col("config/sae/exp_factor")).alias(
                "config/sae/d_sae"
            )
        )
        return df

    df = make_df(tag_input.value)
    return df, make_df
@app.cell
def __(Float, json, np, os):
    # The underscore prefix keeps this helper local to the cell.
    def _load_table(run, name: str) -> Float[np.ndarray, " d_sae"]:
        """
        Load the logged `eval/<name>` table of `run` as a flat array.

        Downloads the first logged artifact whose name contains `eval<name>`
        and reads `eval/<name>.table.json` from the download directory.

        Raises:
            RuntimeError: If anything fails while listing, downloading or
                parsing the artifact.
            ValueError: If no logged artifact matches.
        """
        try:
            for artifact in run.logged_artifacts():
                if f"eval{name}" not in artifact.name:
                    continue

                dpath = artifact.download()
                fpath = os.path.join(dpath, "eval", f"{name}.table.json")
                print(fpath)
                with open(fpath) as fd:
                    raw = json.load(fd)
                return np.array(raw["data"]).reshape(-1)
        except Exception as err:
            # Deliberately broad: callers treat any failure here as "skip this run".
            raise RuntimeError("Wandb sucks.") from err

        raise ValueError(f"{name} not found in run '{run.id}'")

    def load_freqs(run) -> Float[np.ndarray, " d_sae"]:
        """Return the `eval/freqs` table logged by `run` as a flat array."""
        return _load_table(run, "freqs")

    def load_mean_values(run) -> Float[np.ndarray, " d_sae"]:
        """Return the `eval/mean_values` table logged by `run` as a flat array."""
        return _load_table(run, "mean_values")

    return load_freqs, load_mean_values
def test_split_cfgs_on_single_key():
    """Two configs that differ only in `n_workers` should land in separate groups."""
    worker_counts = (12, 16)
    cfgs = [config.Train(n_workers=n) for n in worker_counts]
    expected = [[config.Train(n_workers=n)] for n in worker_counts]

    assert training.split_cfgs(cfgs) == expected


def test_split_cfgs_on_single_key_with_multiple_per_key():
    """Configs sharing the same `n_patches` should be collected into one group."""
    cfgs = [config.Train(n_patches=n) for n in (12, 16, 16, 16)]
    expected = [
        [config.Train(n_patches=12)],
        [config.Train(n_patches=16) for _ in range(3)],
    ]

    assert training.split_cfgs(cfgs) == expected


def test_split_cfgs_on_multiple_keys_with_multiple_per_key():
    """Groups should be keyed on the combination of `n_patches` and `track`."""
    settings = [(12, False), (12, True), (16, True), (16, True), (16, False)]
    cfgs = [config.Train(n_patches=n, track=t) for n, t in settings]
    expected = [
        [config.Train(n_patches=12, track=False)],
        [config.Train(n_patches=12, track=True)],
        [
            config.Train(n_patches=16, track=True),
            config.Train(n_patches=16, track=True),
        ],
        [config.Train(n_patches=16, track=False)],
    ]

    assert training.split_cfgs(cfgs) == expected


def test_split_cfgs_no_bad_keys():
    """Configs that differ only in `sae.sparsity_coeff` should stay in a single group."""

    def make(coeff):
        return config.Train(
            n_patches=12, sae=config.SparseAutoencoder(sparsity_coeff=coeff)
        )

    coeffs = (1e-4, 2e-4, 3e-4, 4e-4, 5e-4)
    cfgs = [make(c) for c in coeffs]
    expected = [[make(c) for c in coeffs]]

    assert training.split_cfgs(cfgs) == expected
def make_hashable(obj):
    """
    Serialize `obj` to a JSON string with dict keys sorted.

    Equal dicts serialize to the same string regardless of key insertion
    order, and the resulting string can be placed inside a tuple used as a
    dict key.

    Raises:
        TypeError: If `obj` is not JSON-serializable.
    """
    encoded = json.dumps(obj, sort_keys=True)
    return encoded
os.listdir("/research/nfs_su_809/workspace/stevens.994/saev/probing") - - mo.stop( - not ckpts, - mo.md("Run `uv run main.py probe --help` to fill out at least one checkpoint."), - ) - - ckpt_dropdown = mo.ui.dropdown(ckpts, label="Checkpoint:", value=ckpts[0]) - return ckpt_dropdown, ckpts - - -@app.cell -def __(ckpt_dropdown): - ckpt_dropdown - return - - -@app.cell -def __(ckpt_dropdown, mo, os): - mo.stop( - ckpt_dropdown.value is None, - mo.md("Run `uv run main.py probe --help` to fill out at least one checkpoint."), - ) - - tasks = os.listdir( - f"/research/nfs_su_809/workspace/stevens.994/saev/probing/{ckpt_dropdown.value}" - ) - - mo.stop( - not tasks, - mo.md("Run `uv run main.py probe --help` to fill out at least one checkpoint."), - ) - - task_dropdown = mo.ui.dropdown(tasks, label="Task:", value=tasks[0]) - return task_dropdown, tasks - - -@app.cell -def __(task_dropdown): - task_dropdown - return - - -@app.cell -def __(ckpt_dropdown, mo, os, task_dropdown): - root = os.path.join( - "/research/nfs_su_809/workspace/stevens.994/saev/probing", - ckpt_dropdown.value, - task_dropdown.value, - ) - - get_neuron_i, set_neuron_i = mo.state(0) - return get_neuron_i, root, set_neuron_i - - -@app.cell -def __(mo, os, root): - neuron_indices = [ - int(name) for name in os.listdir(f"{root}/neurons") if name.isdigit() - ] - neuron_indices = sorted(neuron_indices) - mo.md(f"Found {len(neuron_indices)} saved neurons.") - return (neuron_indices,) - - -@app.cell -def __(mo, neuron_indices, set_neuron_i): - next_button = mo.ui.button( - label="Next", - on_change=lambda _: set_neuron_i(lambda v: (v + 1) % len(neuron_indices)), - ) - - prev_button = mo.ui.button( - label="Previous", - on_change=lambda _: set_neuron_i(lambda v: (v - 1) % len(neuron_indices)), - ) - return next_button, prev_button - - -@app.cell -def __(get_neuron_i, mo, neuron_indices, set_neuron_i): - neuron_slider = mo.ui.slider( - 0, - len(neuron_indices) - 1, - value=get_neuron_i(), - on_change=lambda i: 
set_neuron_i(i), - full_width=True, - ) - return (neuron_slider,) - - -@app.cell -def __(mo): - width_slider = mo.ui.slider(start=1, stop=20, label="Images per row", value=8) - return (width_slider,) - - -@app.cell -def __( - get_neuron_i, - mo, - n, - neuron_indices, - neuron_notes, - neuron_slider, - next_button, - prev_button, - width_slider, -): - label = f"Neuron {n} ({get_neuron_i()}/{len(neuron_indices)}; {get_neuron_i() / len(neuron_indices) * 100:.2f}%)" - - mo.md(f""" - {mo.hstack([prev_button, next_button, label, width_slider], justify="start", gap=1.0)} - {neuron_slider} - - Notes on Neuron {n}: {neuron_notes} - """) - return (label,) - - -@app.cell -def __(mo, root): - def show_img(n: int, i: int): - label = "No label found." - try: - label = open(f"{root}/neurons/{n}/{i}.txt").read().strip() - except FileNotFoundError: - return mo.md(f"*Missing image {i + 1}*") - - return mo.vstack([ - mo.image(f"{root}/neurons/{n}/{i}-original.png"), - mo.image(f"{root}/neurons/{n}/{i}.png"), - mo.md(label), - ]) - - return (show_img,) - - -@app.cell -def __( - batched, - get_neuron_i, - mo, - neuron_indices, - os, - root, - show_img, - width_slider, -): - n = neuron_indices[get_neuron_i()] - i_im = [ - int(filename.removesuffix(".txt")) - for filename in os.listdir(f"{root}/neurons/{n}") - if filename.endswith(".txt") and filename != "notes.txt" - ][:200] - - imgs = [show_img(n, i) for i in i_im] - - rows = [ - mo.hstack(batch, widths="equal") - for batch in batched(imgs, n=width_slider.value) - ] - - mo.vstack(rows) - return i_im, imgs, n, rows - - -@app.cell -def __(n, root): - try: - with open(f"{root}/neurons/{n}/notes.txt") as fd: - neuron_notes = fd.read().strip() - except FileNotFoundError: - neuron_notes = "*no notes.*" - return fd, neuron_notes - - -@app.cell -def __(itertools): - def batched(iterable, n, *, strict=False): - # batched('ABCDEFG', 3) → ABC DEF G - if n < 1: - raise ValueError("n must be at least one") - iterator = iter(iterable) - while batch 
:= tuple(itertools.islice(iterator, n)): - if strict and len(batch) != n: - raise ValueError("batched(): incomplete batch") - yield batch - - return (batched,) - - -@app.cell -def __(): - return - - -@app.cell -def __(): - return - - -if __name__ == "__main__": - app.run()