Commit

Updating docs
samuelstevens committed Nov 21, 2024
1 parent 578e5ee commit 0adacd8
Showing 22 changed files with 152 additions and 4,080 deletions.
55 changes: 13 additions & 42 deletions README.md
@@ -14,53 +14,24 @@ Read [logbook.md](logbook.md) for a detailed log of my thought process.
See [related-work.md](related-work.md) for a list of works training SAEs on vision models.
Please open an issue or a PR if there is missing work.

## Using `saev`

This section is a placeholder right now.
The [docs](https://samuelstevens.me/saev/) should be better (maybe in the future).

I use [uv](https://docs.astral.sh/uv/) for everything now.
## Installation

Generate the activations using a ViT-B/32 pre-trained with the CLIP objective:
Installation is supported with [uv](https://docs.astral.sh/uv/).
saev will likely work with pure pip, conda, etc. but I will not formally support it.

```sh
uv run main.py activations \
--n-workers 12 \
--vit-batch-size 512 \
--d-vit 768 \
--n-layers 3 \
--n-patches-per-img 49 \
--model-ckpt ViT-B-32/openai \
--dump-to /local/scratch/stevens.994/cache/saev \
data:imagenet-dataset \
--data.split train
To install, clone this repository (maybe fork it first if you want).

```
In the project root directory, run `uv run python -m saev --help`.
The first invocation should create a virtual environment and show a help message.

Sweep LR for 10M patches on the second-to-last layer of the [CLS] token.

```sh
uv run main.py sweep \
--sweep configs/baseline.toml \
--n-patches 10_000_000 \
--data.shard-root /local/scratch/stevens.994/cache/saev/4dc22752a94c350ea6045599290cfbc31e3ee96b213d485318e434362b3bbdda \
--data.patches cls \
--data.layer -2
```

Generate webapp images:
## Using `saev`

```sh
uv run main.py webapp \
--ckpt ./checkpoints/cr6sl257/sae.pt \
--data.shard-root /local/scratch/stevens.994/cache/saev/4dc22752a94c350ea6045599290cfbc31e3ee96b213d485318e434362b3bbdda \
--dump-to /local/scratch/stevens.994/cache/saev/webapp/cr6sl257
```
See the [docs](https://samuelstevens.me/saev/) for an overview.

Then run the webapp with:

```sh
uv run marimo edit webapp.py
```
## Roadmap

And make sure `webapp_dir` is `"/local/scratch/stevens.994/cache/saev/webapp/cr6sl257"`.
1. Train models with data scaling (norm, mean) turned on.
2. Train models on ViT-L/14 datasets.
3. Semantic segmentation baseline with linear probe.
4. ADE20K experiment to demonstrate faithfulness.
4 changes: 4 additions & 0 deletions docs/index.html
@@ -38,6 +38,10 @@ <h2>Package Docs</h2>
<dt><a href="probing">probing</a></dt>
<dd><p>Package for probing for individual features in trained SAEs.</p></dd>
</div>
<div class="flex">
<dt><a href="faithfulness">faithfulness</a></dt>
<dd><p>Package for measuring SAE feature faithfulness through feature manipulation.</p></dd>
</div>
</dl>
</article>
</main>
84 changes: 1 addition & 83 deletions docs/saev/activations.html
@@ -66,21 +66,6 @@ <h2 id="args">Args</h2>
<h2 id="returns">Returns</h2>
<p>Directory to where activations should be dumped/loaded from.</p></div>
</dd>
<dt id="saev.activations.get_broden_dataloader"><code class="name flex">
<span>def <span class="ident">get_broden_dataloader</span></span>(<span>cfg: <a title="saev.config.Activations" href="config.html#saev.config.Activations">Activations</a>, preprocess) ‑> torch.utils.data.dataloader.DataLoader</span>
</code></dt>
<dd>
<div class="desc"><p>Get a dataloader for Broden dataset.</p>
<h2 id="args">Args</h2>
<dl>
<dt><strong><code>cfg</code></strong></dt>
<dd>Config.</dd>
<dt><strong><code>preprocess</code></strong></dt>
<dd>Image transform to be applied to each image.</dd>
</dl>
<h2 id="returns">Returns</h2>
<p>A PyTorch Dataloader that yields dictionaries with <code>'image'</code> keys containing image batches and <code>'index'</code> keys containing original dataset indices.</p></div>
</dd>
<dt id="saev.activations.get_dataloader"><code class="name flex">
<span>def <span class="ident">get_dataloader</span></span>(<span>cfg: <a title="saev.config.Activations" href="config.html#saev.config.Activations">Activations</a>, preprocess)</span>
</code></dt>
@@ -97,7 +82,7 @@ <h2 id="returns">Returns</h2>
<p>A PyTorch Dataloader that yields dictionaries with <code>'image'</code> keys containing image batches.</p></div>
</dd>
<dt id="saev.activations.get_dataset"><code class="name flex">
<span>def <span class="ident">get_dataset</span></span>(<span>cfg: <a title="saev.config.ImagenetDataset" href="config.html#saev.config.ImagenetDataset">ImagenetDataset</a> | <a title="saev.config.TreeOfLifeDataset" href="config.html#saev.config.TreeOfLifeDataset">TreeOfLifeDataset</a> | <a title="saev.config.LaionDataset" href="config.html#saev.config.LaionDataset">LaionDataset</a> | <a title="saev.config.ImageFolderDataset" href="config.html#saev.config.ImageFolderDataset">ImageFolderDataset</a> | <a title="saev.config.BrodenDataset" href="config.html#saev.config.BrodenDataset">BrodenDataset</a>, *, transform)</span>
<span>def <span class="ident">get_dataset</span></span>(<span>cfg: <a title="saev.config.ImagenetDataset" href="config.html#saev.config.ImagenetDataset">ImagenetDataset</a> | <a title="saev.config.TreeOfLifeDataset" href="config.html#saev.config.TreeOfLifeDataset">TreeOfLifeDataset</a> | <a title="saev.config.LaionDataset" href="config.html#saev.config.LaionDataset">LaionDataset</a> | <a title="saev.config.ImageFolderDataset" href="config.html#saev.config.ImageFolderDataset">ImageFolderDataset</a>, *, transform)</span>
</code></dt>
<dd>
<div class="desc"><p>Gets the dataset for the current experiment; delegates construction to dataset-specific functions.</p>
@@ -167,12 +152,6 @@ <h2 id="returns">Returns</h2>
<dd>
<div class="desc"><p>Run dataset-specific setup. These setup functions can assume they are the only job running, but they should be idempotent; they should be safe (and ideally cheap) to run multiple times in a row.</p></div>
</dd>
<dt id="saev.activations.setup_broden"><code class="name flex">
<span>def <span class="ident">setup_broden</span></span>(<span>cfg: <a title="saev.config.Activations" href="config.html#saev.config.Activations">Activations</a>)</span>
</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="saev.activations.setup_imagefolder"><code class="name flex">
<span>def <span class="ident">setup_imagefolder</span></span>(<span>cfg: <a title="saev.config.Activations" href="config.html#saev.config.Activations">Activations</a>)</span>
</code></dt>
@@ -873,62 +852,6 @@ <h3>Methods</h3>
</dd>
</dl>
</dd>
<dt id="saev.activations.PreprocessedBroden"><code class="flex name class">
<span>class <span class="ident">PreprocessedBroden</span></span>
<span>(</span><span>cfg: <a title="saev.config.BrodenDataset" href="config.html#saev.config.BrodenDataset">BrodenDataset</a>, transform)</span>
</code></dt>
<dd>
<div class="desc"><p>An abstract class representing a :class:<code><a title="saev.activations.Dataset" href="#saev.activations.Dataset">Dataset</a></code>.</p>
<p>All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:<code>__getitem__</code>, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:<code>__len__</code>, which is expected to return the size of the dataset by many
:class:<code>~torch.utils.data.Sampler</code> implementations and the default options
of :class:<code>~torch.utils.data.DataLoader</code>. Subclasses could also
optionally implement :meth:<code>__getitems__</code>, for speedup batched samples
loading. This method accepts list of indices of samples of batch and returns
list of samples.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>:class:<code>~torch.utils.data.DataLoader</code> by default constructs an index
sampler that yields integral indices.
To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.</p>
</div></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">@beartype.beartype
class PreprocessedBroden(torch.utils.data.Dataset):
    def __init__(self, cfg: config.BrodenDataset, transform):
        import csv

        self.cfg = cfg
        self.transform = transform

        self.samples = []

        with open(os.path.join(cfg.root, &#34;index.csv&#34;)) as fd:
            for row in csv.DictReader(fd):
                self.samples.append(row[&#34;image&#34;])

    def __getitem__(self, i):
        fpath = os.path.join(self.cfg.root, &#34;images&#34;, self.samples[i])
        with open(fpath, &#34;rb&#34;) as fd:
            img = Image.open(fd).convert(&#34;RGB&#34;)
        img = self.transform(img)
        return {&#34;image&#34;: img, &#34;index&#34;: i}

    def __len__(self) -&gt; int:
        return len(self.samples)</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
<li>torch.utils.data.dataset.Dataset</li>
<li>typing.Generic</li>
</ul>
</dd>
<dt id="saev.activations.ShardWriter"><code class="flex name class">
<span>class <span class="ident">ShardWriter</span></span>
<span>(</span><span>cfg: <a title="saev.config.Activations" href="config.html#saev.config.Activations">Activations</a>)</span>
@@ -1552,15 +1475,13 @@ <h3>Methods</h3>
<ul class="">
<li><code><a title="saev.activations.dump" href="#saev.activations.dump">dump</a></code></li>
<li><code><a title="saev.activations.get_acts_dir" href="#saev.activations.get_acts_dir">get_acts_dir</a></code></li>
<li><code><a title="saev.activations.get_broden_dataloader" href="#saev.activations.get_broden_dataloader">get_broden_dataloader</a></code></li>
<li><code><a title="saev.activations.get_dataloader" href="#saev.activations.get_dataloader">get_dataloader</a></code></li>
<li><code><a title="saev.activations.get_dataset" href="#saev.activations.get_dataset">get_dataset</a></code></li>
<li><code><a title="saev.activations.get_default_dataloader" href="#saev.activations.get_default_dataloader">get_default_dataloader</a></code></li>
<li><code><a title="saev.activations.get_laion_dataloader" href="#saev.activations.get_laion_dataloader">get_laion_dataloader</a></code></li>
<li><code><a title="saev.activations.get_tol_dataloader" href="#saev.activations.get_tol_dataloader">get_tol_dataloader</a></code></li>
<li><code><a title="saev.activations.make_vit" href="#saev.activations.make_vit">make_vit</a></code></li>
<li><code><a title="saev.activations.setup" href="#saev.activations.setup">setup</a></code></li>
<li><code><a title="saev.activations.setup_broden" href="#saev.activations.setup_broden">setup_broden</a></code></li>
<li><code><a title="saev.activations.setup_imagefolder" href="#saev.activations.setup_imagefolder">setup_imagefolder</a></code></li>
<li><code><a title="saev.activations.setup_imagenet" href="#saev.activations.setup_imagenet">setup_imagenet</a></code></li>
<li><code><a title="saev.activations.setup_laion" href="#saev.activations.setup_laion">setup_laion</a></code></li>
@@ -1619,9 +1540,6 @@ <h4><code><a title="saev.activations.Metadata" href="#saev.activations.Metadata"
</ul>
</li>
<li>
<h4><code><a title="saev.activations.PreprocessedBroden" href="#saev.activations.PreprocessedBroden">PreprocessedBroden</a></code></h4>
</li>
<li>
<h4><code><a title="saev.activations.ShardWriter" href="#saev.activations.ShardWriter">ShardWriter</a></code></h4>
<ul class="two-column">
<li><code><a title="saev.activations.ShardWriter.acts" href="#saev.activations.ShardWriter.acts">acts</a></code></li>

0 comments on commit 0adacd8
