
Show example images on dashboard
samuelstevens committed Nov 24, 2024
1 parent a1b55f6 commit 30d2ffd
Showing 10 changed files with 491 additions and 349 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -28,6 +28,8 @@ The first invocation should create a virtual environment and show a help message

See the [docs](https://samuelstevens.me/saev/) for an overview.

I recommend feeding the [llms.txt](https://samuelstevens.me/saev/llms.txt) file to any LLM provider so you can ask questions about this project.
For example, on macOS you can run `curl https://samuelstevens.me/saev/llms.txt | pbcopy` to copy the text, then paste it into [https://claude.ai](https://claude.ai) and ask any question you have.

## Roadmap

192 changes: 90 additions & 102 deletions docs/llms.txt

Large diffs are not rendered by default.

292 changes: 150 additions & 142 deletions docs/saev/activations.html

Large diffs are not rendered by default.

72 changes: 63 additions & 9 deletions docs/saev/config.html
@@ -61,7 +61,7 @@ <h2 class="section-title" id="header-classes">Classes</h2>
<dl>
<dt id="saev.config.Activations"><code class="flex name class">
<span>class <span class="ident">Activations</span></span>
<span>(</span><span>data: <a title="saev.config.ImagenetDataset" href="#saev.config.ImagenetDataset">ImagenetDataset</a> | <a title="saev.config.ImageFolderDataset" href="#saev.config.ImageFolderDataset">ImageFolderDataset</a> &lt;factory&gt;, dump_to: str = './shards', model_org: Literal['clip', 'siglip', 'timm', 'dinov2'] = 'clip', model_ckpt: str = 'ViT-L-14/openai', vit_batch_size: int = 1024, n_workers: int = 8, d_vit: int = 1024, layers: list[int] = &lt;factory&gt;, n_patches_per_img: int = 256, cls_token: bool = True, n_patches_per_shard: int = 2400000, seed: int = 42, ssl: bool = True, device: str = 'cuda', slurm: bool = False, slurm_acct: str = 'PAS2136', log_to: str = './logs')</span>
<span>(</span><span>data: <a title="saev.config.ImagenetDataset" href="#saev.config.ImagenetDataset">ImagenetDataset</a> | <a title="saev.config.ImageFolderDataset" href="#saev.config.ImageFolderDataset">ImageFolderDataset</a> | <a title="saev.config.Ade20kDataset" href="#saev.config.Ade20kDataset">Ade20kDataset</a> = &lt;factory&gt;, dump_to: str = './shards', model_family: Literal['clip', 'siglip', 'dinov2'] = 'clip', model_ckpt: str = 'ViT-L-14/openai', vit_batch_size: int = 1024, n_workers: int = 8, d_vit: int = 1024, layers: list[int] = &lt;factory&gt;, n_patches_per_img: int = 256, cls_token: bool = True, n_patches_per_shard: int = 2400000, seed: int = 42, ssl: bool = True, device: str = 'cuda', slurm: bool = False, slurm_acct: str = 'PAS2136', log_to: str = './logs')</span>
</code></dt>
<dd>
<div class="desc"><p>Configuration for calculating and saving ViT activations.</p></div>
@@ -80,8 +80,8 @@ <h2 class="section-title" id="header-classes">Classes</h2>
&#34;&#34;&#34;Which dataset to use.&#34;&#34;&#34;
dump_to: str = os.path.join(&#34;.&#34;, &#34;shards&#34;)
&#34;&#34;&#34;Where to write shards.&#34;&#34;&#34;
model_org: typing.Literal[&#34;clip&#34;, &#34;siglip&#34;, &#34;timm&#34;, &#34;dinov2&#34;] = &#34;clip&#34;
&#34;&#34;&#34;Where to load models from.&#34;&#34;&#34;
model_family: typing.Literal[&#34;clip&#34;, &#34;siglip&#34;, &#34;dinov2&#34;] = &#34;clip&#34;
&#34;&#34;&#34;Which model family.&#34;&#34;&#34;
model_ckpt: str = &#34;ViT-L-14/openai&#34;
&#34;&#34;&#34;Specific model checkpoint.&#34;&#34;&#34;
vit_batch_size: int = 1024
@@ -124,7 +124,7 @@ <h3>Class variables</h3>
<dd>
<div class="desc"><p>Dimension of the ViT activations (depends on model).</p></div>
</dd>
<dt id="saev.config.Activations.data"><code class="name">var <span class="ident">data</span><a title="saev.config.ImagenetDataset" href="#saev.config.ImagenetDataset">ImagenetDataset</a> | <a title="saev.config.ImageFolderDataset" href="#saev.config.ImageFolderDataset">ImageFolderDataset</a></code></dt>
<dt id="saev.config.Activations.data"><code class="name">var <span class="ident">data</span><a title="saev.config.ImagenetDataset" href="#saev.config.ImagenetDataset">ImagenetDataset</a> | <a title="saev.config.ImageFolderDataset" href="#saev.config.ImageFolderDataset">ImageFolderDataset</a> | <a title="saev.config.Ade20kDataset" href="#saev.config.Ade20kDataset">Ade20kDataset</a></code></dt>
<dd>
<div class="desc"><p>Which dataset to use.</p></div>
</dd>
@@ -148,9 +148,9 @@ <h3>Class variables</h3>
<dd>
<div class="desc"><p>Specific model checkpoint.</p></div>
</dd>
<dt id="saev.config.Activations.model_org"><code class="name">var <span class="ident">model_org</span> : Literal['clip', 'siglip', 'timm', 'dinov2']</code></dt>
<dt id="saev.config.Activations.model_family"><code class="name">var <span class="ident">model_family</span> : Literal['clip', 'siglip', 'dinov2']</code></dt>
<dd>
<div class="desc"><p>Where to load models from.</p></div>
<div class="desc"><p>Which model family.</p></div>
</dd>
<dt id="saev.config.Activations.n_patches_per_img"><code class="name">var <span class="ident">n_patches_per_img</span> : int</code></dt>
<dd>
@@ -186,6 +186,53 @@ <h3>Class variables</h3>
</dd>
</dl>
</dd>
<dt id="saev.config.Ade20kDataset"><code class="flex name class">
<span>class <span class="ident">Ade20kDataset</span></span>
<span>(</span><span>root: str = './data/split')</span>
</code></dt>
<dd>
<div class="desc"><p>Configuration for the ADE20K dataset.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Ade20kDataset:
&#34;&#34;&#34;Configuration for the ADE20K dataset.&#34;&#34;&#34;

root: str = os.path.join(&#34;.&#34;, &#34;data&#34;, &#34;split&#34;)
&#34;&#34;&#34;Where the class folders with images are stored.&#34;&#34;&#34;

@property
def n_imgs(self) -&gt; int:
with open(os.path.join(self.root, &#34;sceneCategories.txt&#34;)) as fd:
return len(fd.read().split(&#34;\n&#34;))</code></pre>
</details>
<h3>Class variables</h3>
<dl>
<dt id="saev.config.Ade20kDataset.root"><code class="name">var <span class="ident">root</span> : str</code></dt>
<dd>
<div class="desc"><p>Where the class folders with images are stored.</p></div>
</dd>
</dl>
<h3>Instance variables</h3>
<dl>
<dt id="saev.config.Ade20kDataset.n_imgs"><code class="name">prop <span class="ident">n_imgs</span> : int</code></dt>
<dd>
<div class="desc"></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">@property
def n_imgs(self) -&gt; int:
with open(os.path.join(self.root, &#34;sceneCategories.txt&#34;)) as fd:
return len(fd.read().split(&#34;\n&#34;))</code></pre>
</details>
</dd>
</dl>
</dd>
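To make the behavior of `n_imgs` concrete, here is a minimal standalone sketch that mirrors the `Ade20kDataset` source shown above; the temporary directory and the two example entries in `sceneCategories.txt` are invented for the demo:

```python
import dataclasses
import os
import tempfile


@dataclasses.dataclass(frozen=True)
class Ade20kDataset:
    """Standalone replica of the config class shown above."""

    root: str = os.path.join(".", "data", "split")

    @property
    def n_imgs(self) -> int:
        # Counts newline-separated entries in sceneCategories.txt.
        # Note: a trailing newline in the file adds one empty entry to the count.
        with open(os.path.join(self.root, "sceneCategories.txt")) as fd:
            return len(fd.read().split("\n"))


tmp = tempfile.mkdtemp()
with open(os.path.join(tmp, "sceneCategories.txt"), "w") as fd:
    fd.write("ADE_train_00000001 airport_terminal\nADE_train_00000002 bathroom")

cfg = Ade20kDataset(root=tmp)
print(cfg.n_imgs)  # 2
```

Because the count is based on `str.split("\n")`, a file that ends with a newline reports one more image than it actually lists; callers may want to be aware of that edge case.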
<dt id="saev.config.DataLoad"><code class="flex name class">
<span>class <span class="ident">DataLoad</span></span>
<span>(</span><span>shard_root: str = './shards', patches: Literal['cls', 'patches', 'meanpool'] = 'patches', layer: Union[int, Literal['all', 'meanpool']] = -2, clamp: float = 100000.0, n_random_samples: int = 524288, scale_mean: bool = True, scale_norm: bool = True)</span>
@@ -585,7 +632,7 @@ <h3>Class variables</h3>
</dd>
<dt id="saev.config.Visuals"><code class="flex name class">
<span>class <span class="ident">Visuals</span></span>
<span>(</span><span>ckpt: str = './checkpoints/sae.pt', data: <a title="saev.config.DataLoad" href="#saev.config.DataLoad">DataLoad</a> = &lt;factory&gt;, images: <a title="saev.config.ImagenetDataset" href="#saev.config.ImagenetDataset">ImagenetDataset</a> | <a title="saev.config.ImageFolderDataset" href="#saev.config.ImageFolderDataset">ImageFolderDataset</a> = &lt;factory&gt;, top_k: int = 128, n_workers: int = 16, topk_batch_size: int = 16384, sae_batch_size: int = 16384, epsilon: float = 1e-09, sort_by: Literal['cls', 'img', 'patch'] = 'patch', device: str = 'cuda', dump_to: str = './data', log_freq_range: tuple[float, float] = (-6.0, -2.0), log_value_range: tuple[float, float] = (-1.0, 1.0), include_latents: list[int] = &lt;factory&gt;)</span>
<span>(</span><span>ckpt: str = './checkpoints/sae.pt', data: <a title="saev.config.DataLoad" href="#saev.config.DataLoad">DataLoad</a> = &lt;factory&gt;, images: <a title="saev.config.ImagenetDataset" href="#saev.config.ImagenetDataset">ImagenetDataset</a> | <a title="saev.config.ImageFolderDataset" href="#saev.config.ImageFolderDataset">ImageFolderDataset</a> | <a title="saev.config.Ade20kDataset" href="#saev.config.Ade20kDataset">Ade20kDataset</a> = &lt;factory&gt;, top_k: int = 128, n_workers: int = 16, topk_batch_size: int = 16384, sae_batch_size: int = 16384, epsilon: float = 1e-09, sort_by: Literal['cls', 'img', 'patch'] = 'patch', device: str = 'cuda', dump_to: str = './data', log_freq_range: tuple[float, float] = (-6.0, -2.0), log_value_range: tuple[float, float] = (-1.0, 1.0), include_latents: list[int] = &lt;factory&gt;)</span>
</code></dt>
<dd>
<div class="desc"><p>Configuration for generating visuals from trained SAEs.</p></div>
@@ -673,7 +720,7 @@ <h3>Class variables</h3>
<dd>
<div class="desc"><p>Value to add to avoid log(0).</p></div>
</dd>
<dt id="saev.config.Visuals.images"><code class="name">var <span class="ident">images</span><a title="saev.config.ImagenetDataset" href="#saev.config.ImagenetDataset">ImagenetDataset</a> | <a title="saev.config.ImageFolderDataset" href="#saev.config.ImageFolderDataset">ImageFolderDataset</a></code></dt>
<dt id="saev.config.Visuals.images"><code class="name">var <span class="ident">images</span><a title="saev.config.ImagenetDataset" href="#saev.config.ImagenetDataset">ImagenetDataset</a> | <a title="saev.config.ImageFolderDataset" href="#saev.config.ImageFolderDataset">ImageFolderDataset</a> | <a title="saev.config.Ade20kDataset" href="#saev.config.Ade20kDataset">Ade20kDataset</a></code></dt>
<dd>
<div class="desc"><p>Which images to use.</p></div>
</dd>
@@ -820,7 +867,7 @@ <h4><code><a title="saev.config.Activations" href="#saev.config.Activations">Act
<li><code><a title="saev.config.Activations.layers" href="#saev.config.Activations.layers">layers</a></code></li>
<li><code><a title="saev.config.Activations.log_to" href="#saev.config.Activations.log_to">log_to</a></code></li>
<li><code><a title="saev.config.Activations.model_ckpt" href="#saev.config.Activations.model_ckpt">model_ckpt</a></code></li>
<li><code><a title="saev.config.Activations.model_org" href="#saev.config.Activations.model_org">model_org</a></code></li>
<li><code><a title="saev.config.Activations.model_family" href="#saev.config.Activations.model_family">model_family</a></code></li>
<li><code><a title="saev.config.Activations.n_patches_per_img" href="#saev.config.Activations.n_patches_per_img">n_patches_per_img</a></code></li>
<li><code><a title="saev.config.Activations.n_patches_per_shard" href="#saev.config.Activations.n_patches_per_shard">n_patches_per_shard</a></code></li>
<li><code><a title="saev.config.Activations.n_workers" href="#saev.config.Activations.n_workers">n_workers</a></code></li>
@@ -832,6 +879,13 @@ <h4><code><a title="saev.config.Activations" href="#saev.config.Activations">Act
</ul>
</li>
<li>
<h4><code><a title="saev.config.Ade20kDataset" href="#saev.config.Ade20kDataset">Ade20kDataset</a></code></h4>
<ul class="">
<li><code><a title="saev.config.Ade20kDataset.n_imgs" href="#saev.config.Ade20kDataset.n_imgs">n_imgs</a></code></li>
<li><code><a title="saev.config.Ade20kDataset.root" href="#saev.config.Ade20kDataset.root">root</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="saev.config.DataLoad" href="#saev.config.DataLoad">DataLoad</a></code></h4>
<ul class="two-column">
<li><code><a title="saev.config.DataLoad.clamp" href="#saev.config.DataLoad.clamp">clamp</a></code></li>
1 change: 1 addition & 0 deletions docs/saev/index.html
@@ -125,6 +125,7 @@ <h2 id="visualize-the-learned-features">Visualize the Learned Features</h2>
<p>This will record the top 128 patches, and then save the unique images among those top 128 patches for each feature in the trained SAE.
It will cache these best activations to disk, then start saving images to visualize later on.</p>
<p><code><a title="saev.webapp" href="webapp.html">saev.webapp</a></code> is a small web application based on <a href="https://marimo.io/">marimo</a> to interactively look at these images.</p>
<p>You can run it with <code>uv run marimo edit saev/webapp.py</code>.</p>
<h2 id="sweeps">Sweeps</h2>
<div class="admonition todo">
<p class="admonition-title">TODO</p>
33 changes: 19 additions & 14 deletions docs/saev/nn.html
@@ -73,12 +73,14 @@ <h2 class="section-title" id="header-classes">Classes</h2>
<span>(</span><span>mse: jaxtyping.Float[Tensor, ''], sparsity: jaxtyping.Float[Tensor, ''], ghost_grad: jaxtyping.Float[Tensor, ''], l0: jaxtyping.Float[Tensor, ''], l1: jaxtyping.Float[Tensor, ''])</span>
</code></dt>
<dd>
<div class="desc"><p>Loss(mse, sparsity, ghost_grad, l0, l1)</p></div>
<div class="desc"><p>The composite loss terms for an autoencoder training batch.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">class Loss(typing.NamedTuple):
&#34;&#34;&#34;The composite loss terms for an autoencoder training batch.&#34;&#34;&#34;

mse: Float[Tensor, &#34;&#34;]
&#34;&#34;&#34;Reconstruction loss (mean squared error).&#34;&#34;&#34;
sparsity: Float[Tensor, &#34;&#34;]
@@ -173,8 +175,15 @@ <h3>Instance variables</h3>
self.logger = logging.getLogger(f&#34;sae(seed={cfg.seed})&#34;)

def forward(
self, x: Float[Tensor, &#34;batch d_model&#34;], dead_neuron_mask: None = None
self, x: Float[Tensor, &#34;batch d_model&#34;]
) -&gt; tuple[Float[Tensor, &#34;batch d_model&#34;], Float[Tensor, &#34;batch d_sae&#34;], Loss]:
&#34;&#34;&#34;
Given x, calculates the reconstructed x_hat, the intermediate activations f_x, and the loss.

Arguments:
x: a batch of ViT activations.
&#34;&#34;&#34;

# Remove encoder bias as per Anthropic
h_pre = (
einops.einsum(
@@ -220,7 +229,9 @@ <h3>Instance variables</h3>

@torch.no_grad()
def normalize_w_dec(self):
# Make sure the W_dec is still unit-norm
&#34;&#34;&#34;
Set W_dec to unit-norm columns.
&#34;&#34;&#34;
if self.cfg.normalize_w_dec:
self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)

@@ -259,18 +270,12 @@ <h3>Class variables</h3>
<h3>Methods</h3>
<dl>
<dt id="saev.nn.SparseAutoencoder.forward"><code class="name flex">
<span>def <span class="ident">forward</span></span>(<span>self, x: jaxtyping.Float[Tensor, 'batch d_model'], dead_neuron_mask: None = None) ‑> tuple[jaxtyping.Float[Tensor, 'batch d_model'], jaxtyping.Float[Tensor, 'batch d_sae'], <a title="saev.nn.Loss" href="#saev.nn.Loss">Loss</a>]</span>
<span>def <span class="ident">forward</span></span>(<span>self, x: jaxtyping.Float[Tensor, 'batch d_model']) ‑> tuple[jaxtyping.Float[Tensor, 'batch d_model'], jaxtyping.Float[Tensor, 'batch d_sae'], <a title="saev.nn.Loss" href="#saev.nn.Loss">Loss</a>]</span>
</code></dt>
<dd>
<div class="desc"><p>Define the computation performed at every call.</p>
<p>Should be overridden by all subclasses.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Although the recipe for forward pass needs to be defined within
this function, one should call the :class:<code>Module</code> instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.</p>
</div></div>
<div class="desc"><p>Given x, calculates the reconstructed x_hat, the intermediate activations f_x, and the loss.</p>
<h2 id="arguments">Arguments</h2>
<p>x: a batch of ViT activations.</p></div>
</dd>
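The forward pass documented above can be grounded with a generic numpy sketch of a sparse autoencoder. This follows the standard recipe (subtract the decoder bias, encode, ReLU, decode) rather than saev's exact implementation, and every weight here is a random placeholder, not a trained value:

```python
import numpy as np

# Toy dimensions; all weights are random placeholders, not saev's trained parameters.
rng = np.random.default_rng(0)
d_model, d_sae, batch = 4, 8, 2
W_enc = rng.normal(size=(d_model, d_sae))
b_enc = np.zeros(d_sae)
W_dec = rng.normal(size=(d_sae, d_model))
b_dec = np.zeros(d_model)


def forward(x: np.ndarray):
    """Standard SAE recipe: subtract decoder bias, encode, ReLU, decode."""
    h_pre = (x - b_dec) @ W_enc + b_enc  # remove decoder bias before encoding
    f_x = np.maximum(h_pre, 0.0)         # sparse latent activations
    x_hat = f_x @ W_dec + b_dec          # reconstruction
    mse = float(np.mean((x_hat - x) ** 2))          # reconstruction loss
    l1 = float(np.abs(f_x).sum(axis=1).mean())      # sparsity penalty
    return x_hat, f_x, (mse, l1)


x = rng.normal(size=(batch, d_model))
x_hat, f_x, (mse, l1) = forward(x)
print(x_hat.shape, f_x.shape)  # (2, 4) (2, 8)
```

The returned triple matches the documented shape contract: a reconstruction with the same shape as `x`, a `(batch, d_sae)` activation matrix, and the loss terms.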
<dt id="saev.nn.SparseAutoencoder.init_b_dec"><code class="name flex">
<span>def <span class="ident">init_b_dec</span></span>(<span>self, vit_acts: jaxtyping.Float[Tensor, 'n d_vit'])</span>
@@ -282,7 +287,7 @@ <h3>Methods</h3>
<span>def <span class="ident">normalize_w_dec</span></span>(<span>self)</span>
</code></dt>
<dd>
<div class="desc"></div>
<div class="desc"><p>Set W_dec to unit-norm columns.</p></div>
</dd>
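The normalization is easy to check in isolation. This sketch uses numpy instead of torch so it runs standalone, but it normalizes along `dim=1` with `keepdim` exactly as the torch code above does:

```python
import numpy as np


def normalize_w_dec(W_dec: np.ndarray) -> np.ndarray:
    """Scale W_dec along dim 1 to unit L2 norm, mirroring the torch code above."""
    return W_dec / np.linalg.norm(W_dec, axis=1, keepdims=True)


W = np.array([[3.0, 4.0], [0.0, 2.0]])  # toy (d_sae=2, d_vit=2) decoder
W_unit = normalize_w_dec(W)
print(np.linalg.norm(W_unit, axis=1))  # [1. 1.]
```

Keeping decoder directions at unit norm is a common SAE training trick: it prevents the model from shrinking latent activations while compensating with large decoder weights.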
<dt id="saev.nn.SparseAutoencoder.remove_parallel_grads"><code class="name flex">
<span>def <span class="ident">remove_parallel_grads</span></span>(<span>self)</span>
20 changes: 16 additions & 4 deletions docs/saev/visuals.html
@@ -69,10 +69,22 @@ <h2 id="returns">Returns</h2>
<span>def <span class="ident">get_new_topk</span></span>(<span>val1: jaxtyping.Float[Tensor, 'd_sae k'], i1: jaxtyping.Int[Tensor, 'd_sae k'], val2: jaxtyping.Float[Tensor, 'd_sae k'], i2: jaxtyping.Int[Tensor, 'd_sae k'], k: int) ‑> tuple[jaxtyping.Float[Tensor, 'd_sae k'], jaxtyping.Int[Tensor, 'd_sae k']]</span>
</code></dt>
<dd>
<div class="desc"><div class="admonition todo">
<p class="admonition-title">TODO</p>
<p>document this function.</p>
</div></div>
<div class="desc"><p>Picks out the new top k values among val1 and val2. Also keeps track of i1 and i2, the indices of those values in the original dataset.</p>
<h2 id="args">Args</h2>
<dl>
<dt><strong><code>val1</code></strong></dt>
<dd>top k original SAE values.</dd>
<dt><strong><code>i1</code></strong></dt>
<dd>the patch indices of those original top k values.</dd>
<dt><strong><code>val2</code></strong></dt>
<dd>top k incoming SAE values.</dd>
<dt><strong><code>i2</code></strong></dt>
<dd>the patch indices of those incoming top k values.</dd>
<dt><strong><code>k</code></strong></dt>
<dd>the number of top values to keep.</dd>
</dl>
<h2 id="returns">Returns</h2>
<p>The new top k values and their patch indices.</p></div>
</dd>
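The merge described above can be sketched with numpy; the real function operates on torch tensors, and this toy uses a single SAE latent row with k=2:

```python
import numpy as np


def get_new_topk(val1, i1, val2, i2, k):
    """Merge two (d_sae, k) top-k sets, keeping the k largest values and their indices."""
    vals = np.concatenate([val1, val2], axis=1)  # (d_sae, 2k)
    idxs = np.concatenate([i1, i2], axis=1)
    order = np.argsort(-vals, axis=1)[:, :k]     # positions of the k largest per row
    return (
        np.take_along_axis(vals, order, axis=1),
        np.take_along_axis(idxs, order, axis=1),
    )


val1 = np.array([[9.0, 5.0]])  # running top-k SAE values
i1 = np.array([[10, 11]])      # their patch indices
val2 = np.array([[7.0, 1.0]])  # incoming batch's top-k values
i2 = np.array([[20, 21]])      # their patch indices
v, i = get_new_topk(val1, i1, val2, i2, 2)
print(v, i)  # [[9. 7.]] [[10 20]]
```

Merging per batch this way keeps memory bounded: only k values and indices per latent survive between batches, instead of the full activation stream.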
<dt id="saev.visuals.get_sae_acts"><code class="name flex">
<span>def <span class="ident">get_sae_acts</span></span>(<span>vit_acts: jaxtyping.Float[Tensor, 'n d_vit'], sae: <a title="saev.nn.SparseAutoencoder" href="nn.html#saev.nn.SparseAutoencoder">SparseAutoencoder</a>, cfg: <a title="saev.config.Visuals" href="config.html#saev.config.Visuals">Visuals</a>) ‑> jaxtyping.Float[Tensor, 'n d_sae']</span>