From 3e8132b00c99b57d62e8b487d4bdace7bef9a19c Mon Sep 17 00:00:00 2001
From: Samuel Stevens
Date: Fri, 25 Oct 2024 09:46:05 -0400
Subject: [PATCH] Updating docs

---
 docs/saev/config.html   | 141 +++++++++++++-----
 docs/saev/index.html    |  60 +++++---
 docs/saev/modeling.html | 320 +++++++++++++++------------------------
 docs/saev/sessions.html |  57 -------
 docs/saev/vits.html     |  57 -------
 saev/__init__.py        |   6 +
 saev/sessions.py        |   1 -
 saev/vits.py            |   1 -
 8 files changed, 260 insertions(+), 383 deletions(-)
 delete mode 100644 docs/saev/sessions.html
 delete mode 100644 docs/saev/vits.html
 delete mode 100644 saev/sessions.py
 delete mode 100644 saev/vits.py

diff --git a/docs/saev/config.html b/docs/saev/config.html
index 16ecaea..e731249 100644
--- a/docs/saev/config.html
+++ b/docs/saev/config.html
@@ -39,7 +39,7 @@

Classes

class Config
-(image_width: int = 224, image_height: int = 224, model: str = 'ViT-L-14/openai', module_name: str = 'resid', block_layer: int = -2, data: Huggingface | Webdataset = <factory>, n_workers: int = 8, d_in: int = 1024, n_epochs: int = 3, n_batches_in_store: int = 15, vit_batch_size: int = 1024, expansion_factor: int = 64, l1_coefficient: float = 8e-05, lr: float = 0.0004, lr_warm_up_steps: int = 500, batch_size: int = 1024, use_ghost_grads: bool = True, feature_sampling_window: int = 64, resample_batches: int = 32, feature_reinit_scale: float = 0.2, dead_feature_window: int = 64, dead_feature_estimation_method: str = 'no_fire', dead_feature_threshold: float = 1e-06, log_to_wandb: bool = True, wandb_project: str = 'saev', wandb_log_freq: int = 10, device: str = 'cuda', seed: int = 42, dtype: str = 'float32', checkpoint_path: str = 'checkpoints')
+(image_width: int = 224, image_height: int = 224, model: str = 'ViT-L-14/openai', module_name: str = 'resid', block_layer: int = -2, data: Imagenet | TreeOfLife = <factory>, n_workers: int = 8, d_vit: int = 1024, n_epochs: int = 3, n_batches_in_store: int = 15, vit_batch_size: int = 1024, expansion_factor: int = 64, l1_coefficient: float = 8e-05, lr: float = 0.0004, lr_warm_up_steps: int = 500, batch_size: int = 1024, use_ghost_grads: bool = True, feature_sampling_window: int = 64, resample_batches: int = 32, feature_reinit_scale: float = 0.2, dead_feature_window: int = 64, dead_feature_estimation_method: str = 'no_fire', dead_feature_threshold: float = 1e-06, log_to_wandb: bool = True, wandb_project: str = 'saev', wandb_log_freq: int = 10, device: str = 'cuda', seed: int = 42, dtype: str = 'float32', checkpoint_path: str = 'checkpoints', slurm: bool = False, slurm_acct: str = 'PAS2136', log_to: str = './logs')

Configuration for training a sparse autoencoder on a vision transformer.

@@ -58,14 +58,16 @@

Classes

 image_width: int = 224
 image_height: int = 224
 model: str = "ViT-L-14/openai"
+"""Model string, for use with open_clip."""
 module_name: str = "resid"
 block_layer: int = -2
-data: Huggingface | Webdataset = dataclasses.field(default_factory=Huggingface)
+data: Imagenet | TreeOfLife = dataclasses.field(default_factory=Imagenet)
+"""Which dataset to use."""
 n_workers: int = 8
 """Number of dataloader workers."""

 # SAE Parameters
-d_in: int = 1024
+d_vit: int = 1024

 # Activation Store Parameters
 n_epochs: int = 3
@@ -101,13 +103,19 @@

Classes

 dtype: str = "float32"
 checkpoint_path: str = "checkpoints"
+slurm: bool = False
+"""Whether to use submitit to run jobs on a slurm cluster."""
+slurm_acct: str = "PAS2136"
+"""Slurm account string."""
+log_to: str = "./logs"
+
 @property
 def store_size(self) -> int:
     return self.n_batches_in_store * self.batch_size

 @property
 def d_sae(self) -> int:
-    return self.d_in * self.expansion_factor
+    return self.d_vit * self.expansion_factor

 @property
 def run_name(self) -> str:
@@ -129,13 +137,13 @@
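As a quick orientation to the renamed and added fields, here is a hedged construction sketch; it assumes Config and Imagenet are importable from saev.config as the docs in this patch suggest, and the values are defaults or illustrative:

    from saev.config import Config, Imagenet  # assumed import path

    cfg = Config(
        data=Imagenet(),       # which dataset to use (formerly Huggingface | Webdataset)
        d_vit=1024,            # ViT activation dimension (renamed from d_in)
        slurm=True,            # run jobs on a slurm cluster via submitit
        slurm_acct="PAS2136",  # slurm account string
        log_to="./logs",
    )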

Class variables

-var d_in : int
+var d_vit : int

-var data : Huggingface | Webdataset
+var data : Imagenet | TreeOfLife
+
+    Which dataset to use.

var dead_feature_estimation_method : str
@@ -181,6 +189,10 @@

Class variables

+var log_to : str
+
var log_to_wandb : bool
@@ -195,7 +207,7 @@

Class variables

var model : str
+
+    Model string, for use with open_clip.

var module_name : str
@@ -221,6 +233,14 @@

Class variables

+var slurm : bool
+
+    Whether to use submitit to run jobs on a slurm cluster.
+
+var slurm_acct : str
+
+    Slurm account string.
+
var use_ghost_grads : bool
@@ -249,7 +269,7 @@

Instance variables

@property
 def d_sae(self) -> int:
-    return self.d_in * self.expansion_factor
+    return self.d_vit * self.expansion_factor
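(With the defaults above, d_sae = d_vit * expansion_factor = 1024 * 64 = 65,536 SAE features.)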
prop run_name : str
@@ -280,22 +300,23 @@

Instance variables

-class Huggingface
-(name: str = 'ILSVRC/imagenet-1k')
-
-    Configuration for datasets from HuggingFace.
+class Imagenet
+(name: str = 'ILSVRC/imagenet-1k')
+
+    Configuration for HuggingFace Imagenet.

Expand source code
@beartype.beartype
 @dataclasses.dataclass(frozen=True)
-class Huggingface:
-    """Configuration for datasets from HuggingFace."""
+class Imagenet:
+    """Configuration for HuggingFace Imagenet."""
 
     name: str = "ILSVRC/imagenet-1k"
+    """Dataset name. Probably don't want to change this."""
 
     @property
     def n_imgs(self) -> int:
@@ -308,14 +329,14 @@ 

Instance variables

Class variables

 var name : str
+
+    Dataset name. Probably don't want to change this.

Instance variables

 prop n_imgs : int
@@ -334,27 +355,27 @@

Instance variables

-class Webdataset
-(url: str = '/fs/ess/PAS2136/open_clip/data/evobio10m-v3.3/224x224/train/shard-{000000..000159}.tar', n_imgs: int = 9562377)
-
-    Configuration for webdataset (like TreeOfLife-10M).
+class TreeOfLife
+(metadata: str = 'treeoflife-10m.json', label_key: str = '.taxonomic_name.txt')
+
+    Configuration for the TreeOfLife-10M webdataset.

Webdatasets are designed for random sampling of the entire dataset so that over multiple epochs, every sample is seen, on average, the same number of times. However, for training sparse autoencoders, we need to calculate ViT activations exactly once for each example in the dataset. Webdatasets support this through the wids library.

Here is a short discussion of the steps required to use saev with webdatasets.

First, you will need to use widsindex (installed with the webdataset library) to create a metadata file used by wids. You can see an example file here. To generate my own metadata file, I ran this command:

-uv run widsindex create --name treeoflife-10m --output meta.json '/fs/ess/PAS2136/open_clip/data/evobio10m-v3.3/224x224/train/shard-{000000..000159}.tar'
+uv run widsindex create \
+    --name treeoflife-10m \
+    --output treeoflife-10m.json \
+    '/fs/ess/PAS2136/open_clip/data/evobio10m-v3.3/224x224/train/shard-{000000..000159}.tar'
 
-

It took a long time (more than an hour) and generated a meta.json file.

+

It took a long time (more than an hour, less than 3 hours) and generated a treeoflife-10m.json file.

Expand source code
@beartype.beartype
 @dataclasses.dataclass(frozen=True)
-class Webdataset:
+class TreeOfLife:
     """
-    Configuration for webdataset (like TreeOfLife-10M).
+    Configuration for the TreeOfLife-10M webdataset.
 
     Webdatasets are designed for random sampling of the entire dataset so that over multiple epochs, every sample is seen, on average, the same number of times. However, for training sparse autoencoders, we need to calculate ViT activations exactly once for each example in the dataset. Webdatasets support this through the [`wids`](https://github.com/webdataset/webdataset?tab=readme-ov-file#the-wids-library-for-indexed-webdatasets) library.
 
@@ -363,28 +384,64 @@ 

Instance variables

     First, you will need to use `widsindex` (installed with the webdataset library) to create a metadata file used by wids. You can see an example file [here](https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train.json). To generate my own metadata file, I ran this command:

     ```
-    uv run widsindex create --name treeoflife-10m --output meta.json '/fs/ess/PAS2136/open_clip/data/evobio10m-v3.3/224x224/train/shard-{000000..000159}.tar'
+    uv run widsindex create \
+        --name treeoflife-10m \
+        --output treeoflife-10m.json \
+        '/fs/ess/PAS2136/open_clip/data/evobio10m-v3.3/224x224/train/shard-{000000..000159}.tar'
     ```

-    It took a long time (more than an hour) and generated a `meta.json` file.
+    It took a long time (more than an hour, less than 3 hours) and generated a `treeoflife-10m.json` file.
     """

-    url: str = "/fs/ess/PAS2136/open_clip/data/evobio10m-v3.3/224x224/train/shard-{000000..000159}.tar"
+    metadata: str = "treeoflife-10m.json"
     """Path to dataset shards."""

-    n_imgs: int = 9562377
-    """Number of images in dataset."""
+    label_key: str = ".taxonomic_name.txt"
+    """Which key to use as the label."""
+
+    @property
+    def n_imgs(self) -> int:
+        with open(self.metadata) as fd:
+            metadata = json.load(fd)
+
+        return (
+            np.array([shard["nsamples"] for shard in metadata["shardlist"]])
+            .sum()
+            .item()
+        )

Class variables

-var n_imgs : int
-
-    Number of images in dataset.
+var label_key : str
+
+    Which key to use as the label.

-var url : str
+var metadata : str

     Path to dataset shards.

+Instance variables
+
+prop n_imgs : int
+
+Expand source code
+
+    @property
+    def n_imgs(self) -> int:
+        with open(self.metadata) as fd:
+            metadata = json.load(fd)
+
+        return (
+            np.array([shard["nsamples"] for shard in metadata["shardlist"]])
+            .sum()
+            .item()
+        )
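To make the new property concrete, here is a hedged sketch of the metadata shape that n_imgs reads: a top-level shardlist whose entries each carry an nsamples count. The shard names and counts are invented, and the saev.config import path is an assumption based on the docs above:

    import json

    from saev.config import TreeOfLife  # assumed import path

    # Invented two-shard metadata file in the shape n_imgs expects.
    metadata = {
        "shardlist": [
            {"url": "shard-000000.tar", "nsamples": 60000},
            {"url": "shard-000001.tar", "nsamples": 59750},
        ]
    }
    with open("treeoflife-10m.json", "w") as fd:
        json.dump(metadata, fd)

    cfg = TreeOfLife(metadata="treeoflife-10m.json")
    print(cfg.n_imgs)  # 60000 + 59750 = 119750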
diff --git a/docs/saev/index.html b/docs/saev/index.html
index 6613ecd..f8376f4 100644
--- a/docs/saev/index.html
+++ b/docs/saev/index.html
@@ -5,7 +5,7 @@
 saev API documentation
@@ -27,6 +27,8 @@

    Package saev

+saev is a Python package for training sparse autoencoders (SAEs) on vision transformers (ViTs) in PyTorch.
+
+The main entrypoint to the package is in main.py; use python main.py --help to see the options and documentation for the script.

    Sub-modules

    @@ -45,20 +47,13 @@

    Sub-modules

    saev.modeling
+
+    modeling is the main module for the saev package and contains all the important non-config classes.
+    It's fine for this package to be slow to import …

-saev.sessions

 saev.training

-saev.vits

    saev.webapp
    @@ -74,7 +69,7 @@

    Classes

class Config
-(image_width: int = 224, image_height: int = 224, model: str = 'ViT-L-14/openai', module_name: str = 'resid', block_layer: int = -2, data: Huggingface | Webdataset = <factory>, n_workers: int = 8, d_in: int = 1024, n_epochs: int = 3, n_batches_in_store: int = 15, vit_batch_size: int = 1024, expansion_factor: int = 64, l1_coefficient: float = 8e-05, lr: float = 0.0004, lr_warm_up_steps: int = 500, batch_size: int = 1024, use_ghost_grads: bool = True, feature_sampling_window: int = 64, resample_batches: int = 32, feature_reinit_scale: float = 0.2, dead_feature_window: int = 64, dead_feature_estimation_method: str = 'no_fire', dead_feature_threshold: float = 1e-06, log_to_wandb: bool = True, wandb_project: str = 'saev', wandb_log_freq: int = 10, device: str = 'cuda', seed: int = 42, dtype: str = 'float32', checkpoint_path: str = 'checkpoints')
+(image_width: int = 224, image_height: int = 224, model: str = 'ViT-L-14/openai', module_name: str = 'resid', block_layer: int = -2, data: Imagenet | TreeOfLife = <factory>, n_workers: int = 8, d_vit: int = 1024, n_epochs: int = 3, n_batches_in_store: int = 15, vit_batch_size: int = 1024, expansion_factor: int = 64, l1_coefficient: float = 8e-05, lr: float = 0.0004, lr_warm_up_steps: int = 500, batch_size: int = 1024, use_ghost_grads: bool = True, feature_sampling_window: int = 64, resample_batches: int = 32, feature_reinit_scale: float = 0.2, dead_feature_window: int = 64, dead_feature_estimation_method: str = 'no_fire', dead_feature_threshold: float = 1e-06, log_to_wandb: bool = True, wandb_project: str = 'saev', wandb_log_freq: int = 10, device: str = 'cuda', seed: int = 42, dtype: str = 'float32', checkpoint_path: str = 'checkpoints', slurm: bool = False, slurm_acct: str = 'PAS2136', log_to: str = './logs')

    Configuration for training a sparse autoencoder on a vision transformer.

    @@ -93,14 +88,16 @@

    Classes

 image_width: int = 224
 image_height: int = 224
 model: str = "ViT-L-14/openai"
+"""Model string, for use with open_clip."""
 module_name: str = "resid"
 block_layer: int = -2
-data: Huggingface | Webdataset = dataclasses.field(default_factory=Huggingface)
+data: Imagenet | TreeOfLife = dataclasses.field(default_factory=Imagenet)
+"""Which dataset to use."""
 n_workers: int = 8
 """Number of dataloader workers."""

 # SAE Parameters
-d_in: int = 1024
+d_vit: int = 1024

 # Activation Store Parameters
 n_epochs: int = 3
@@ -136,13 +133,19 @@

    Classes

 dtype: str = "float32"
 checkpoint_path: str = "checkpoints"
+slurm: bool = False
+"""Whether to use submitit to run jobs on a slurm cluster."""
+slurm_acct: str = "PAS2136"
+"""Slurm account string."""
+log_to: str = "./logs"
+
 @property
 def store_size(self) -> int:
     return self.n_batches_in_store * self.batch_size

 @property
 def d_sae(self) -> int:
-    return self.d_in * self.expansion_factor
+    return self.d_vit * self.expansion_factor

 @property
 def run_name(self) -> str:
@@ -164,13 +167,13 @@

    Class variables

-var d_in : int
+var d_vit : int

-var data : Huggingface | Webdataset
+var data : Imagenet | TreeOfLife
+
+    Which dataset to use.

    var dead_feature_estimation_method : str
    @@ -216,6 +219,10 @@

    Class variables

+var log_to : str
+
    var log_to_wandb : bool
    @@ -230,7 +237,7 @@

    Class variables

    var model : str
+
+    Model string, for use with open_clip.

    var module_name : str
    @@ -256,6 +263,14 @@

    Class variables

+var slurm : bool
+
+    Whether to use submitit to run jobs on a slurm cluster.
+
+var slurm_acct : str
+
+    Slurm account string.
+
    var use_ghost_grads : bool
    @@ -284,7 +299,7 @@

    Instance variables

    @property
     def d_sae(self) -> int:
    -    return self.d_in * self.expansion_factor
+    return self.d_vit * self.expansion_factor
    prop run_name : str
diff --git a/docs/saev/modeling.html b/docs/saev/modeling.html
index 9ed64b1..d22e82c 100644
--- a/docs/saev/modeling.html
+++ b/docs/saev/modeling.html
@@ -5,7 +5,8 @@
 saev.modeling API documentation
@@ -27,6 +28,8 @@

    Module saev.modeling

+modeling is the main module for the saev package and contains all the important non-config classes.
+It's fine for this package to be slow to import (see saev.config for a discussion of import times).

    @@ -35,8 +38,8 @@

    Module saev.modeling

    Functions

-def filter_no_caption_or_no_image(sample)
+def dump(filename: str, model_kwargs: dict[str, object], model: torch.nn.modules.module.Module)
    @@ -45,7 +48,14 @@

    Functions

    def get_acts_filepath(cfg: Config) ‑> str
+Return the activations filepath based on the relevant values of a config.
+
+Args
+    cfg
+        Config for experiment.
+
+Returns
+    Filepath to where activations should be dumped/loaded from.

def get_cache_dir() ‑> str
@@ -53,8 +63,8 @@

    Functions

    Get cache directory from environment variables, defaulting to the current working directory (.)

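The lookup order is not shown in this diff, so the snippet below is only a sketch of the described behavior; SAEV_CACHE and HF_HOME are assumed, illustrative variable names, not confirmed saev API:

    import os

    def get_cache_dir() -> str:
        # Hypothetical environment variables; the real names are not in this patch.
        for var in ("SAEV_CACHE", "HF_HOME"):
            value = os.environ.get(var)
            if value:
                return value
        # Default: the current working directory.
        return "."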
-def get_hf_dataloader(cfg: Config, preprocess) ‑> torch.utils.data.dataloader.DataLoader
+def get_imagenet_dataloader(cfg: Config, preprocess) ‑> torch.utils.data.dataloader.DataLoader
    @@ -63,19 +73,36 @@

    Functions

    def get_sae_batches(cfg: Config, acts_store: CachedActivationsStore) ‑> jaxtyping.Float[Tensor, 'store_size d_model']
-Get a batch of vit activations
+Get a batch of vit activations to re-initialize the SAE.
+
+Args
+    cfg
+        Config.
+    acts_store
+        Activation store.

-def get_wds_dataloader(cfg: Config, preprocess) ‑> torch.utils.data.dataloader.DataLoader
+def get_tol_dataloader(cfg: Config, preprocess) ‑> torch.utils.data.dataloader.DataLoader
+Get a dataloader for the TreeOfLife-10M dataset.
+
+Currently does not include a true index or label in the loaded examples.
+
+Args
+    cfg
+        Config.
+    preprocess
+        Image transform to be applied to each image.
+
+Returns
+    A PyTorch Dataloader that yields dictionaries with 'image' keys containing image batches.

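A hedged usage sketch; the import paths and the preprocess transform are assumptions, and the Config must point data at TreeOfLife for this loader to make sense:

    import torchvision.transforms as T

    from saev import config, modeling  # assumed import paths

    cfg = config.Config(data=config.TreeOfLife())
    preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor()])

    dataloader = modeling.get_tol_dataloader(cfg, preprocess)
    for batch in dataloader:
        images = batch["image"]  # per the docs above, batches are dicts with an 'image' key
        break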
-def log_and_continue(exn)
-
-    Call in an exception handler to ignore any exception, issue a warning, and continue.
+def load(filename, cls: type[torch.nn.modules.module.Module]) ‑> torch.nn.modules.module.Module
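The new dump/load pair appears to replace the removed save_model/load_from_pretrained methods (see the SparseAutoencoder diff below). A hedged round-trip sketch matching only the signatures shown; the filename and kwargs are illustrative:

    from saev import modeling  # assumed import path

    kwargs = dict(d_vit=1024, d_sae=65536, l1_coeff=8e-5, use_ghost_grads=True)
    sae = modeling.SparseAutoencoder(**kwargs)

    # Persist the weights plus the kwargs needed to rebuild the module.
    modeling.dump("checkpoints/sae.pt", kwargs, sae)

    # Later: reconstruct an instance of the same class from disk.
    sae = modeling.load("checkpoints/sae.pt", modeling.SparseAutoencoder)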
def save_acts(cfg: Config, vit: RecordedVit)
@@ -136,7 +163,7 @@

    Classes

         else:
             raise ValueError(f"Invalid value '{on_missing}' for arg 'on_missing'.")

-        self.shape = (cfg.data.n_imgs, cfg.d_in)
+        self.shape = (cfg.data.n_imgs, cfg.d_vit)
         # TODO
         # self.labels = torch.tensor(dataset["label"])
         self.labels = None
@@ -353,16 +380,21 @@

    Methods

class Session
-(vit: RecordedVit, sae: SparseAutoencoder, acts_store: CachedActivationsStore)
+(cfg: Config, vit: RecordedVit, sae: SparseAutoencoder, acts_store: CachedActivationsStore)
-Session(vit, sae, acts_store)
+Session is a group of instances of the main classes for saev experiments.

    Expand source code
    @beartype.beartype
     class Session(typing.NamedTuple):
    +    """
    +    Session is a group of instances of the main classes for saev experiments.
    +    """
    +
    +    cfg: config.Config
         vit: RecordedVit
         sae: SparseAutoencoder
         acts_store: CachedActivationsStore
    @@ -378,7 +410,7 @@ 

    Methods

         sae = SparseAutoencoder(cfg)
         acts_store = CachedActivationsStore(cfg, vit, on_missing="error")
-        return cls(vit, sae, acts_store)
+        return cls(cfg, vit, sae, acts_store)

     @classmethod
     def from_disk(cls, path) -> "Session":
@@ -389,7 +421,7 @@

    Methods

         vit, _, acts_store = cls.from_cfg(cfg)
         sae = SparseAutoencoder.load_from_pretrained(path)
-        return cls(vit, sae, acts_store)
+        return cls(cfg, vit, sae, acts_store)
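A hedged sketch of building a Session from a config (names follow the code above; from_disk works analogously given a checkpoint path):

    from saev import config, modeling  # assumed import paths

    cfg = config.Config()
    session = modeling.Session.from_cfg(cfg)

    # NamedTuple fields, in order: cfg, vit, sae, acts_store.
    cfg, vit, sae, acts_store = session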

    Ancestors

      @@ -414,50 +446,28 @@

      Instance variables

 var acts_store : CachedActivationsStore

-    Alias for field number 2
+    Alias for field number 3

+var cfg : Config
+
+    Alias for field number 0

 var sae : SparseAutoencoder

-    Alias for field number 1
+    Alias for field number 2

 var vit : RecordedVit

-    Alias for field number 0
+    Alias for field number 1

class SparseAutoencoder
-(cfg: Config)
+(d_vit: int, d_sae: int, l1_coeff: float, use_ghost_grads: bool)
-    Base class for all neural network modules.
-
-    Your models should also subclass this class.
-
-    Modules can also contain other Modules, allowing to nest them in
-    a tree structure. You can assign the submodules as regular attributes::
-
-        import torch.nn as nn
-        import torch.nn.functional as F
-
-        class Model(nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.conv1 = nn.Conv2d(1, 20, 5)
-                self.conv2 = nn.Conv2d(20, 20, 5)
-
-            def forward(self, x):
-                x = F.relu(self.conv1(x))
-                return F.relu(self.conv2(x))
-
-    Submodules assigned in this way will be registered, and will have their
-    parameters converted too when you call :meth:`to`, etc.
-
-    Note
-        As per the example above, an __init__() call to the parent class
-        must be made before assignment on the child.
-
-    :ivar training: Boolean represents whether this module is in training or
-        evaluation mode.
-    :vartype training: bool
+    Sparse auto-encoder (SAE) using L1 sparsity penalty.

    Initialize internal Module state, shared by both nn.Module and ScriptModule.

    @@ -465,58 +475,49 @@

    Instance variables

    @beartype.beartype
     class SparseAutoencoder(torch.nn.Module):
    -    def __init__(self, cfg: config.Config):
    +    """
    +    Sparse auto-encoder (SAE) using L1 sparsity penalty.
    +    """
    +
    +    l1_coeff: float
    +    use_ghost_grads: bool
    +
    +    def __init__(self, d_vit: int, d_sae: int, l1_coeff: float, use_ghost_grads: bool):
             super().__init__()
    -        if not isinstance(cfg.d_in, int):
    -            raise ValueError(
    -                f"d_in must be an int but was {cfg.d_in=}; {type(cfg.d_in)=}"
    -            )
     
    -        self.cfg = cfg
    -        self.l1_coefficient = cfg.l1_coefficient
    -        self.dtype = cfg.dtype
    -        self.device = cfg.device
    +        self.l1_coeff = l1_coeff
    +        self.use_ghost_grads = use_ghost_grads
     
    +        # Initialize the weights.
             # NOTE: if using resampling neurons method, you must ensure that we initialise the weights in the order W_enc, b_enc, W_dec, b_dec
             self.W_enc = torch.nn.Parameter(
    -            torch.nn.init.kaiming_uniform_(
    -                torch.empty(cfg.d_in, cfg.d_sae, dtype=self.dtype, device=self.device)
    -            )
    -        )
    -        self.b_enc = torch.nn.Parameter(
    -            torch.zeros(cfg.d_sae, dtype=self.dtype, device=self.device)
    +            torch.nn.init.kaiming_uniform_(torch.empty(d_vit, d_sae))
             )
    +        self.b_enc = torch.nn.Parameter(torch.zeros(d_sae))
     
             self.W_dec = torch.nn.Parameter(
    -            torch.nn.init.kaiming_uniform_(
    -                torch.empty(cfg.d_sae, cfg.d_in, dtype=self.dtype, device=self.device)
    -            )
    +            torch.nn.init.kaiming_uniform_(torch.empty(d_sae, d_vit))
             )
     
             with torch.no_grad():
                 # Anthropic normalizes this to have unit columns
                 self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
     
    -        self.b_dec = torch.nn.Parameter(
    -            torch.zeros(cfg.d_in, dtype=self.dtype, device=self.device)
    -        )
    +        self.b_dec = torch.nn.Parameter(torch.zeros(d_vit))
     
         @jaxtyped(typechecker=beartype.beartype)
         def forward(self, x: Float[Tensor, "batch d_model"], dead_neuron_mask=None):
    -        # move x to correct dtype
    -        x = x.to(self.dtype)
    -
             # Remove encoder bias as per Anthropic
             h_pre = (
                 einops.einsum(
    -                x - self.b_dec, self.W_enc, "... d_in, d_in d_sae -> ... d_sae"
    +                x - self.b_dec, self.W_enc, "... d_vit, d_vit d_sae -> ... d_sae"
                 )
                 + self.b_enc
             )
             f_x = torch.nn.functional.relu(h_pre)
     
             x_hat = (
    -            einops.einsum(f_x, self.W_dec, "... d_sae, d_sae d_in -> ... d_in")
    +            einops.einsum(f_x, self.W_dec, "... d_sae, d_sae d_vit -> ... d_vit")
                 + self.b_dec
             )
     
    @@ -525,9 +526,9 @@ 

    Instance variables

             torch.pow((x_hat - x.float()), 2)
             / (x**2).sum(dim=-1, keepdim=True).sqrt()
         )

-        mse_loss_ghost_resid = torch.tensor(0.0, dtype=self.dtype, device=self.device)
+        ghost_loss = torch.tensor(0.0, dtype=mse_loss.dtype, device=mse_loss.device)
         # gate on config and training so evals is not slowed down.
-        if self.cfg.use_ghost_grads and self.training and dead_neuron_mask.sum() > 0:
+        if self.use_ghost_grads and self.training and dead_neuron_mask.sum() > 0:
             assert dead_neuron_mask is not None

             # ghost protocol
@@ -541,42 +542,38 @@

    Instance variables

             ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask, :]
             l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
             norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
-            ghost_out = ghost_out * norm_scaling_factor[:, None].detach()
+            ghost_out *= norm_scaling_factor[:, None].detach()

             # 3.
-            mse_loss_ghost_resid = (
+            ghost_loss = (
                 torch.pow((ghost_out - residual.detach().float()), 2)
                 / (residual.detach() ** 2).sum(dim=-1, keepdim=True).sqrt()
             )
-            mse_rescaling_factor = (mse_loss / (mse_loss_ghost_resid + 1e-6)).detach()
-            mse_loss_ghost_resid = mse_rescaling_factor * mse_loss_ghost_resid
+            mse_rescaling_factor = (mse_loss / (ghost_loss + 1e-6)).detach()
+            ghost_loss *= mse_rescaling_factor

-        mse_loss_ghost_resid = mse_loss_ghost_resid.mean()
+        ghost_loss = ghost_loss.mean()
         mse_loss = mse_loss.mean()
         sparsity = torch.abs(f_x).sum(dim=1).mean(dim=(0,))
-        l1_loss = self.l1_coefficient * sparsity
-        loss = mse_loss + l1_loss + mse_loss_ghost_resid
+        l1_loss = self.l1_coeff * sparsity
+        loss = mse_loss + l1_loss + ghost_loss

-        return x_hat, f_x, loss, mse_loss, l1_loss, mse_loss_ghost_resid
+        return x_hat, f_x, loss, mse_loss, l1_loss, ghost_loss

     @torch.no_grad()
-    def initialize_b_dec(self, acts_store: CachedActivationsStore):
+    def initialize_b_dec(self, cfg: config.Config, acts_store: CachedActivationsStore):
         previous_b_dec = self.b_dec.clone().cpu()
-        assert isinstance(acts_store, CachedActivationsStore)
-        all_activations = get_sae_batches(self.cfg, acts_store).detach().cpu()
-        out = all_activations.mean(dim=0)
+        all_activations = get_sae_batches(cfg, acts_store).detach().cpu()
+        mean = all_activations.mean(dim=0)

         previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
-        distances = torch.norm(all_activations - out, dim=-1)
+        distances = torch.norm(all_activations - mean, dim=-1)

-        print("Reinitializing b_dec with mean of activations")
-        print(
-            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
-        )
-        print(f"New distances: {distances.median(0).values.mean().item()}")
+        print(f"Prev dist: {previous_distances.median(0).values.mean().item()}")
+        print(f"New dist: {distances.median(0).values.mean().item()}")

-        self.b_dec.data = out.to(self.dtype).to(self.device)
+        self.b_dec.data = mean.to(self.b_dec.dtype).to(self.b_dec.device)

     @torch.no_grad()
     def set_decoder_norm_to_unit_norm(self):
@@ -586,109 +583,34 @@

    Instance variables

     def remove_gradient_parallel_to_decoder_directions(self):
         """
         Update grads so that they remove the parallel component
-        (d_sae, d_in) shape
+        (d_sae, d_vit) shape
         """

         parallel_component = einops.einsum(
             self.W_dec.grad,
             self.W_dec.data,
-            "d_sae d_in, d_sae d_in -> d_sae",
+            "d_sae d_vit, d_sae d_vit -> d_sae",
         )

         self.W_dec.grad -= einops.einsum(
             parallel_component,
             self.W_dec.data,
-            "d_sae, d_sae d_in -> d_sae d_in",
-        )
-
-    def save_model(self, path: str):
-        """
-        Basic save function for the model. Saves the model's state_dict and the config used to train it.
-        """
-
-        # check if path exists
-        folder = os.path.dirname(path)
-        os.makedirs(folder, exist_ok=True)
-
-        state_dict = {"cfg": self.cfg, "state_dict": self.state_dict()}
-
-        if path.endswith(".pt"):
-            torch.save(state_dict, path)
-        elif path.endswith("pkl.gz"):
-            with gzip.open(path, "wb") as f:
-                pickle.dump(state_dict, f)
-        else:
-            raise ValueError(
-                f"Unexpected file extension: {path}, supported extensions are .pt and .pkl.gz"
-            )
-
-        print(f"Saved model to {path}")
-
-    @classmethod
-    def load_from_pretrained(cls, path: str):
-        """
-        Load function for the model. Loads the model's state_dict and the config used to train it.
-        This method can be called directly on the class, without needing an instance.
-        """
-
-        # Ensure the file exists
-        if not os.path.isfile(path):
-            raise FileNotFoundError(f"No file found at specified path: {path}")
-
-        # Load the state dictionary
-        if path.endswith(".pt"):
-            try:
-                state_dict = torch.load(path, weights_only=False)
-            except Exception as e:
-                raise IOError(f"Error loading the state dictionary from .pt file: {e}")
-        elif path.endswith(".pkl.gz"):
-            try:
-                with gzip.open(path, "rb") as f:
-                    state_dict = pickle.load(f)
-            except Exception as e:
-                raise IOError(
-                    f"Error loading the state dictionary from .pkl.gz file: {e}"
-                )
-        elif path.endswith(".pkl"):
-            try:
-                with open(path, "rb") as f:
-                    state_dict = pickle.load(f)
-            except Exception as e:
-                raise IOError(f"Error loading the state dictionary from .pkl file: {e}")
-        else:
-            raise ValueError(
-                f"Unexpected file extension: {path}, supported extensions are .pt, .pkl, and .pkl.gz"
-            )
-
-        # Ensure the loaded state contains both 'cfg' and 'state_dict'
-        if "cfg" not in state_dict or "state_dict" not in state_dict:
-            raise ValueError(
-                "The loaded state dictionary must contain 'cfg' and 'state_dict' keys"
-            )
-
-        # Create an instance of the class using the loaded configuration
-        instance = cls(cfg=state_dict["cfg"])
-        instance.load_state_dict(state_dict["state_dict"])
-
-        return instance
-
-    def get_name(self):
-        assert isinstance(self.cfg, config.Config)
-        return f"sparse_autoencoder_{self.cfg.model}_{self.cfg.block_layer}_{self.cfg.module_name}_{self.cfg.d_sae}"
+            "d_sae, d_sae d_vit -> d_sae d_vit",
+        )
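A hedged smoke-test sketch of the refactored module; the constructor arguments follow the new signature, and in eval mode self.training is False, so the ghost-grads branch short-circuits before dead_neuron_mask is touched:

    import torch

    from saev import modeling  # assumed import path

    sae = modeling.SparseAutoencoder(d_vit=1024, d_sae=65536, l1_coeff=8e-5, use_ghost_grads=True)
    sae.eval()  # skip the ghost-grads branch

    x = torch.randn(8, 1024)  # stand-in batch of ViT activations
    x_hat, f_x, loss, mse_loss, l1_loss, ghost_loss = sae(x)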

    Ancestors

    • torch.nn.modules.module.Module
-Static methods
-
-def load_from_pretrained(cls, path: str)
-
-    Load function for the model. Loads the model's state_dict and the config used to train it.
-    This method can be called directly on the class, without needing an instance.
+Class variables
+
+var l1_coeff : float
+
+var use_ghost_grads : bool

    Methods

    @@ -699,14 +621,8 @@

    Methods

-def get_name(self)
-
-def initialize_b_dec(self, acts_store: CachedActivationsStore)
+def initialize_b_dec(self, cfg: Config, acts_store: CachedActivationsStore)
    @@ -716,13 +632,7 @@

    Methods

    Update grads so that they remove the parallel component -(d_sae, d_in) shape

    -
    -
    -def save_model(self, path: str) -
    -
    -

    Basic save function for the model. Saves the model's state_dict and the config used to train it.

    +(d_sae, d_vit) shape

    def set_decoder_norm_to_unit_norm(self) @@ -747,13 +657,13 @@

    Methods

diff --git a/docs/saev/sessions.html b/docs/saev/sessions.html
deleted file mode 100644
index ff2beb3..0000000
--- a/docs/saev/sessions.html
+++ /dev/null
@@ -1,57 +0,0 @@
-saev.sessions API documentation
-Module saev.sessions
diff --git a/docs/saev/vits.html b/docs/saev/vits.html
deleted file mode 100644
index 6da39d6..0000000
--- a/docs/saev/vits.html
+++ /dev/null
@@ -1,57 +0,0 @@
-saev.vits API documentation
-Module saev.vits
diff --git a/saev/__init__.py b/saev/__init__.py
index 786c82d..a9bd478 100644
--- a/saev/__init__.py
+++ b/saev/__init__.py
@@ -1,3 +1,9 @@
+"""
+saev is a Python package for training sparse autoencoders (SAEs) on vision transformers (ViTs) in PyTorch.
+
+The main entrypoint to the package is in [main.py](https://github.com/samuelstevens/saev/blob/main/main.py); use `python main.py --help` to see the options and documentation for the script.
+"""
+
 from .config import Config

 __all__ = ["Config"]
diff --git a/saev/sessions.py b/saev/sessions.py
deleted file mode 100644
index 8b13789..0000000
--- a/saev/sessions.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/saev/vits.py b/saev/vits.py
deleted file mode 100644
index 8b13789..0000000
--- a/saev/vits.py
+++ /dev/null
@@ -1 +0,0 @@
-