Skip to content

Commit

Permalink
find top k features for ade20k class
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelstevens committed Dec 2, 2024
1 parent 0a829ef commit 964ec75
Show file tree
Hide file tree
Showing 10 changed files with 415 additions and 7 deletions.
6 changes: 4 additions & 2 deletions contrib/semseg/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ def train(


@beartype.beartype
def visuals():
print("Not implemented.")
def visuals(cfg: typing.Annotated[config.Visuals, tyro.conf.arg(name="")]):
    """
    Entry point for the `visuals` subcommand: hand the parsed config to `visuals.main`.

    Arguments:
        cfg: Parsed `config.Visuals` options.
    """
    # The sibling module is imported at call time rather than at module import
    # time (presumably so other subcommands do not pay for it -- TODO confirm).
    # It is aliased because its name is the same as this function's.
    from . import visuals as visuals_module

    visuals_module.main(cfg)


if __name__ == "__main__":
Expand Down
24 changes: 24 additions & 0 deletions contrib/semseg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,30 @@ class Train:
log_to: str = os.path.join(".", "logs")


@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Visuals:
    """
    Options for proposing SAE features that match a single ADE20K class.
    """

    sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
    """Path to the sae.pt file."""
    acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad)
    """Configuration for the saved ADE20K training ViT activations."""
    imgs: saev.config.Ade20kDataset = dataclasses.field(
        default_factory=lambda: saev.config.Ade20kDataset(split="training")
    )
    """Configuration for the ADE20K training dataset."""
    batch_size: int = 128
    """Batch size for calculating F1 scores."""
    n_workers: int = 32
    """Number of dataloader workers."""
    label_threshold: float = 0.9
    """Fraction of a patch's pixels that must share one label for the patch to be used."""
    device: str = "cuda"
    """Hardware for SAE inference."""
    ade20k_cls: int = 29
    """ADE20K class to probe for."""
    k: int = 32
    """Top K features to save."""


@beartype.beartype
def grid(cfg: Train, sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]:
cfgs, errs = [], []
Expand Down
2 changes: 2 additions & 0 deletions contrib/semseg/reproduce.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
This sub-module reproduces the results from Section 4.2 of our paper.

# Overview

As an overview:
Expand Down
140 changes: 140 additions & 0 deletions contrib/semseg/visuals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
Propose features for manual verification.
"""

from . import config, training
import einops
import beartype
import torch
import numpy as np
from jaxtyping import jaxtyped, Int, Shaped

import saev.nn
import saev.helpers


@beartype.beartype
@torch.no_grad
def main(cfg: config.Visuals):
    """
    Rank SAE features by F1 score for one ADE20K class and print a follow-up command.

    Every patch whose pixels are (almost) all one label is passed through the SAE. A feature "predicts" the class when its activation is positive; true positives, false positives and false negatives are accumulated per feature over the whole dataset. The `cfg.k` features with the highest F1 are printed as part of a `saev visuals` command line.

    Arguments:
        cfg: Options; see `config.Visuals`.
    """
    sae = saev.nn.load(cfg.sae_ckpt)
    sae = sae.to(cfg.device)

    dataset = training.Dataset(cfg.acts, cfg.imgs)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.n_workers,
        shuffle=False,
        persistent_workers=(cfg.n_workers > 0),
    )

    # TODO: change from hard-coded values
    patch_size = 16

    d_sae = sae.cfg.d_sae
    tp = torch.zeros((d_sae,), dtype=torch.int64, device=cfg.device)
    fp = torch.zeros((d_sae,), dtype=torch.int64, device=cfg.device)
    fn = torch.zeros((d_sae,), dtype=torch.int64, device=cfg.device)

    for batch in saev.helpers.progress(dataloader):
        # Group the pixel labels by patch: the last axis holds one patch's pixels.
        pixel_labels = einops.rearrange(
            batch["pixel_labels"],
            "batch (w pw) (h ph) -> batch w h (pw ph)",
            pw=patch_size,
            ph=patch_size,
        )
        unique, counts = axis_unique(pixel_labels.numpy(), null_value=0)

        # Keep patches where the label in slot 0 covers more than
        # `label_threshold` of the pixels.
        # NOTE(review): slot 0 is the *smallest* label in the patch, not the most
        # frequent one, and label 0 gets a count of 0 because null_value=0. A
        # patch dominated by a label that is not its minimum is therefore
        # dropped -- confirm that this is intended.
        n_pixels = pixel_labels.shape[-1]
        idx = counts[:, :, :, 0] > (n_pixels * cfg.label_threshold)
        if not idx.any():
            # Nothing to score in this batch; the counters would not change.
            continue

        acts = batch["acts"][idx].to(cfg.device)
        labels = unique[idx][:, 0]

        _, f_x, _ = sae(acts)

        pred = f_x > 0
        true = torch.from_numpy(labels == cfg.ade20k_cls).view(-1, 1).to(cfg.device)

        tp += (pred & true).sum(dim=0)
        fp += (pred & ~true).sum(dim=0)
        fn += (~pred & true).sum(dim=0)

    # A zero denominator implies tp == 0, so clamping gives F1 = 0 instead of a
    # NaN that topk would rank above every real score.
    denom = 2 * tp + fp + fn
    f1 = (2 * tp) / denom.clamp(min=1)
    # topk raises if asked for more entries than there are features.
    indices = f1.topk(min(cfg.k, f1.numel())).indices.tolist()
    latents = " ".join(str(i) for i in indices)

    scale_mean_flag = (
        "--data.scale-mean" if cfg.acts.scale_mean else "--data.no-scale-mean"
    )
    scale_norm_flag = (
        "--data.scale-norm" if cfg.acts.scale_norm else "--data.no-scale-norm"
    )

    print("Run this command to save best images:")
    print()
    # The checkpoint and shard root live on `cfg.sae_ckpt` and `cfg.acts`;
    # `config.Visuals` has no `ckpt` or `data` field.
    print(
        f" uv run python -m saev visuals --ckpt {cfg.sae_ckpt}"
        f" --include-latents {latents}"
        f" --data.shard-root {cfg.acts.shard_root}"
        f" {scale_mean_flag} {scale_norm_flag}"
        f" images:ade20k-dataset --images.root {cfg.imgs.root}"
        f" --images.split {cfg.imgs.split}"
    )


@jaxtyped(typechecker=beartype.beartype)
def axis_unique(
    a: Shaped[np.ndarray, "*axes"],
    axis: int = -1,
    return_counts: bool = True,
    *,
    null_value: int = -1,
) -> (
    Shaped[np.ndarray, "*axes"]
    | tuple[Shaped[np.ndarray, "*axes"], Int[np.ndarray, "*axes"]]
):
    """
    Calculate unique values and their counts along any axis of an array.

    Values are sorted along `axis`. The first occurrence of each value stays in place and every later duplicate is overwritten with `null_value`, so the output has the same shape as the input.

    Arguments:
        a: Input array.
        axis: The axis along which to find unique values.
        return_counts: If true, also return the count of each unique value.
        null_value: Sentinel written over duplicates. Input entries that already equal `null_value` cannot be told apart from the sentinel and are reported with a count of 0.

    Returns:
        unique: Sorted values along `axis`, with `null_value` replacing duplicates.
        counts: (only if return_counts=True) At the position of each unique value, the number of times it occurs along `axis`; 0 everywhere else.
    """
    assert isinstance(axis, int)

    # Move the target axis to the end for consistent processing
    a_transformed = np.moveaxis(a, axis, -1)

    # After sorting, equal values sit next to each other
    sorted_a = np.sort(a_transformed, axis=-1)

    # duplicates[..., i] is True when position i + 1 repeats position i
    duplicates = sorted_a[..., 1:] == sorted_a[..., :-1]

    unique = sorted_a.copy()
    unique[..., 1:][duplicates] = null_value

    if not return_counts:
        # Move axis back to original position
        return np.moveaxis(unique, -1, axis)

    shape = sorted_a.shape
    n = shape[-1]

    # Runs are located from the duplicate mask instead of from non-zero entries
    # of `unique`, so the result does not depend on null_value being 0.
    is_first = np.ones(shape, dtype=bool)
    is_first[..., 1:] = ~duplicates

    positions = np.arange(n)
    # Index at which a run starts, or n where no run starts
    starts = np.where(is_first, positions, n)

    # run_end[..., i] is the start of the next run after position i (n if none),
    # computed as a reversed running minimum over the later positions.
    run_end = np.full(shape, n, dtype=int)
    if n > 1:
        run_end[..., :-1] = np.minimum.accumulate(starts[..., :0:-1], axis=-1)[
            ..., ::-1
        ]

    # Entries equal to null_value are treated as absent and get a count of 0.
    counted = is_first & (sorted_a != null_value)
    count_matrix = np.where(counted, run_end - positions, 0).astype(int)

    # Move axes back to original positions
    unique = np.moveaxis(unique, -1, axis)
    count_matrix = np.moveaxis(count_matrix, -1, axis)

    return unique, count_matrix
2 changes: 1 addition & 1 deletion docs/contrib/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ <h2 class="section-title" id="header-submodules">Sub-modules</h2>
<dl>
<dt><code class="name"><a title="contrib.semseg" href="semseg/index.html">contrib.semseg</a></code></dt>
<dd>
<div class="desc"><p>Overview</p></div>
<div class="desc"><p>This sub-module reproduces the results from Section 4.2 of our paper</p></div>
</dd>
</dl>
</section>
Expand Down
87 changes: 87 additions & 0 deletions docs/contrib/semseg/config.html
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,79 @@ <h3>Class variables</h3>
</dd>
</dl>
</dd>
<dt id="contrib.semseg.config.Visuals"><code class="flex name class">
<span>class <span class="ident">Visuals</span></span>
<span>(</span><span>sae_ckpt: str = './checkpoints/sae.pt', acts: <a title="saev.config.DataLoad" href="../../saev/config.html#saev.config.DataLoad">DataLoad</a> = &lt;factory&gt;, imgs: <a title="saev.config.Ade20kDataset" href="../../saev/config.html#saev.config.Ade20kDataset">Ade20kDataset</a> = &lt;factory&gt;, batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda', ade20k_cls: int = 29, k: int = 32)</span>
</code></dt>
<dd>
<div class="desc"><p>Visuals(sae_ckpt: str = './checkpoints/sae.pt', acts: saev.config.DataLoad = &lt;factory&gt;, imgs: saev.config.Ade20kDataset = &lt;factory&gt;, batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda', ade20k_cls: int = 29, k: int = 32)</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">@beartype.beartype
@dataclasses.dataclass(frozen=True)
class Visuals:
sae_ckpt: str = os.path.join(&#34;.&#34;, &#34;checkpoints&#34;, &#34;sae.pt&#34;)
&#34;&#34;&#34;Path to the sae.pt file.&#34;&#34;&#34;
acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad)
&#34;&#34;&#34;Configuration for the saved ADE20K training ViT activations.&#34;&#34;&#34;
imgs: saev.config.Ade20kDataset = dataclasses.field(
default_factory=lambda: saev.config.Ade20kDataset(split=&#34;training&#34;)
)
&#34;&#34;&#34;Configuration for the ADE20K training dataset.&#34;&#34;&#34;
batch_size: int = 128
&#34;&#34;&#34;Batch size for calculating F1 scores.&#34;&#34;&#34;
n_workers: int = 32
&#34;&#34;&#34;Number of dataloader workers.&#34;&#34;&#34;
label_threshold: float = 0.9
device: str = &#34;cuda&#34;
&#34;Hardware for SAE inference.&#34; &#34;&#34;
ade20k_cls: int = 29
&#34;&#34;&#34;ADE20K class to probe for.&#34;&#34;&#34;
k: int = 32
&#34;&#34;&#34;Top K features to save.&#34;&#34;&#34;</code></pre>
</details>
<h3>Class variables</h3>
<dl>
<dt id="contrib.semseg.config.Visuals.acts"><code class="name">var <span class="ident">acts</span><a title="saev.config.DataLoad" href="../../saev/config.html#saev.config.DataLoad">DataLoad</a></code></dt>
<dd>
<div class="desc"><p>Configuration for the saved ADE20K training ViT activations.</p></div>
</dd>
<dt id="contrib.semseg.config.Visuals.ade20k_cls"><code class="name">var <span class="ident">ade20k_cls</span> : int</code></dt>
<dd>
<div class="desc"><p>ADE20K class to probe for.</p></div>
</dd>
<dt id="contrib.semseg.config.Visuals.batch_size"><code class="name">var <span class="ident">batch_size</span> : int</code></dt>
<dd>
<div class="desc"><p>Batch size for calculating F1 scores.</p></div>
</dd>
<dt id="contrib.semseg.config.Visuals.device"><code class="name">var <span class="ident">device</span> : str</code></dt>
<dd>
<div class="desc"><p>Hardware for SAE inference.</p></div>
</dd>
<dt id="contrib.semseg.config.Visuals.imgs"><code class="name">var <span class="ident">imgs</span><a title="saev.config.Ade20kDataset" href="../../saev/config.html#saev.config.Ade20kDataset">Ade20kDataset</a></code></dt>
<dd>
<div class="desc"><p>Configuration for the ADE20K training dataset.</p></div>
</dd>
<dt id="contrib.semseg.config.Visuals.k"><code class="name">var <span class="ident">k</span> : int</code></dt>
<dd>
<div class="desc"><p>Top K features to save.</p></div>
</dd>
<dt id="contrib.semseg.config.Visuals.label_threshold"><code class="name">var <span class="ident">label_threshold</span> : float</code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt id="contrib.semseg.config.Visuals.n_workers"><code class="name">var <span class="ident">n_workers</span> : int</code></dt>
<dd>
<div class="desc"><p>Number of dataloader workers.</p></div>
</dd>
<dt id="contrib.semseg.config.Visuals.sae_ckpt"><code class="name">var <span class="ident">sae_ckpt</span> : str</code></dt>
<dd>
<div class="desc"><p>Path to the sae.pt file.</p></div>
</dd>
</dl>
</dd>
</dl>
</section>
</article>
Expand Down Expand Up @@ -185,6 +258,20 @@ <h4><code><a title="contrib.semseg.config.Train" href="#contrib.semseg.config.Tr
<li><code><a title="contrib.semseg.config.Train.weight_decay" href="#contrib.semseg.config.Train.weight_decay">weight_decay</a></code></li>
</ul>
</li>
<li>
<h4><code><a title="contrib.semseg.config.Visuals" href="#contrib.semseg.config.Visuals">Visuals</a></code></h4>
<ul class="two-column">
<li><code><a title="contrib.semseg.config.Visuals.acts" href="#contrib.semseg.config.Visuals.acts">acts</a></code></li>
<li><code><a title="contrib.semseg.config.Visuals.ade20k_cls" href="#contrib.semseg.config.Visuals.ade20k_cls">ade20k_cls</a></code></li>
<li><code><a title="contrib.semseg.config.Visuals.batch_size" href="#contrib.semseg.config.Visuals.batch_size">batch_size</a></code></li>
<li><code><a title="contrib.semseg.config.Visuals.device" href="#contrib.semseg.config.Visuals.device">device</a></code></li>
<li><code><a title="contrib.semseg.config.Visuals.imgs" href="#contrib.semseg.config.Visuals.imgs">imgs</a></code></li>
<li><code><a title="contrib.semseg.config.Visuals.k" href="#contrib.semseg.config.Visuals.k">k</a></code></li>
<li><code><a title="contrib.semseg.config.Visuals.label_threshold" href="#contrib.semseg.config.Visuals.label_threshold">label_threshold</a></code></li>
<li><code><a title="contrib.semseg.config.Visuals.n_workers" href="#contrib.semseg.config.Visuals.n_workers">n_workers</a></code></li>
<li><code><a title="contrib.semseg.config.Visuals.sae_ckpt" href="#contrib.semseg.config.Visuals.sae_ckpt">sae_ckpt</a></code></li>
</ul>
</li>
</ul>
</li>
</ul>
Expand Down
8 changes: 7 additions & 1 deletion docs/contrib/semseg/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1">
<meta name="generator" content="pdoc3 0.11.1">
<title>contrib.semseg API documentation</title>
<meta name="description" content="Overview">
<meta name="description" content="This sub-module reproduces the results from Section 4.2 of our paper">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/sanitize.min.css" integrity="sha512-y1dtMcuvtTMJc1yPgEqF0ZjQbhnc/bFhyvIyVNb9Zk5mIGtqVaAB1Ttl28su8AvFMOY0EwRbAe+HCLqj6W7/KA==" crossorigin>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/13.0.0/typography.min.css" integrity="sha512-Y1DYSb995BAfxobCkKepB1BqJJTPrOp3zPL74AWFugHHmmdcvO+C48WLrUOlhGMc0QG7AE3f7gmvvcrmX2fDoA==" crossorigin>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/default.min.css" crossorigin>
Expand All @@ -27,6 +27,7 @@
<h1 class="title">Module <code>contrib.semseg</code></h1>
</header>
<section id="section-intro">
<p>This sub-module reproduces the results from Section 4.2 of our paper.</p>
<h1 id="overview">Overview</h1>
<p>As an overview:</p>
<ol>
Expand Down Expand Up @@ -66,6 +67,10 @@ <h2 class="section-title" id="header-submodules">Sub-modules</h2>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="contrib.semseg.visuals" href="visuals.html">contrib.semseg.visuals</a></code></dt>
<dd>
<div class="desc"><p>Propose features for manual verification.</p></div>
</dd>
</dl>
</section>
<section>
Expand Down Expand Up @@ -98,6 +103,7 @@ <h2 class="section-title" id="header-submodules">Sub-modules</h2>
<li><code><a title="contrib.semseg.dashboard2" href="dashboard2.html">contrib.semseg.dashboard2</a></code></li>
<li><code><a title="contrib.semseg.interactive" href="interactive/index.html">contrib.semseg.interactive</a></code></li>
<li><code><a title="contrib.semseg.training" href="training.html">contrib.semseg.training</a></code></li>
<li><code><a title="contrib.semseg.visuals" href="visuals.html">contrib.semseg.visuals</a></code></li>
</ul>
</li>
</ul>
Expand Down
Loading

0 comments on commit 964ec75

Please sign in to comment.