diff --git a/contrib/semseg/__main__.py b/contrib/semseg/__main__.py
index e6cd1da..f2ed6e1 100644
--- a/contrib/semseg/__main__.py
+++ b/contrib/semseg/__main__.py
@@ -43,8 +43,10 @@ def train(
@beartype.beartype
-def visuals():
- print("Not implemented.")
+def visuals(cfg: typing.Annotated[config.Visuals, tyro.conf.arg(name="")]):
+ from . import visuals
+
+ visuals.main(cfg)
if __name__ == "__main__":
diff --git a/contrib/semseg/config.py b/contrib/semseg/config.py
index cd5c59b..19a65ed 100644
--- a/contrib/semseg/config.py
+++ b/contrib/semseg/config.py
@@ -41,6 +41,30 @@ class Train:
log_to: str = os.path.join(".", "logs")
+@beartype.beartype
+@dataclasses.dataclass(frozen=True)
+class Visuals:
+ sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
+ """Path to the sae.pt file."""
+ acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad)
+ """Configuration for the saved ADE20K training ViT activations."""
+ imgs: saev.config.Ade20kDataset = dataclasses.field(
+ default_factory=lambda: saev.config.Ade20kDataset(split="training")
+ )
+ """Configuration for the ADE20K training dataset."""
+ batch_size: int = 128
+ """Batch size for calculating F1 scores."""
+ n_workers: int = 32
+ """Number of dataloader workers."""
+ label_threshold: float = 0.9
+ device: str = "cuda"
+ "Hardware for SAE inference." ""
+ ade20k_cls: int = 29
+ """ADE20K class to probe for."""
+ k: int = 32
+ """Top K features to save."""
+
+
@beartype.beartype
def grid(cfg: Train, sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]:
cfgs, errs = [], []
diff --git a/contrib/semseg/reproduce.md b/contrib/semseg/reproduce.md
index cbe235a..bcebfd9 100644
--- a/contrib/semseg/reproduce.md
+++ b/contrib/semseg/reproduce.md
@@ -1,3 +1,5 @@
+This sub-module reproduces the results from Section 4.2 of our paper.
+
# Overview
As an overview:
diff --git a/contrib/semseg/visuals.py b/contrib/semseg/visuals.py
new file mode 100644
index 0000000..149eb8e
--- /dev/null
+++ b/contrib/semseg/visuals.py
@@ -0,0 +1,140 @@
+"""
+Propose features for manual verification.
+"""
+
+from . import config, training
+import einops
+import beartype
+import torch
+import numpy as np
+from jaxtyping import jaxtyped, Int, Shaped
+
+import saev.nn
+import saev.helpers
+
+
+@beartype.beartype
+@torch.no_grad
+def main(cfg: config.Visuals):
+ sae = saev.nn.load(cfg.sae_ckpt)
+ sae = sae.to(cfg.device)
+
+ dataset = training.Dataset(cfg.acts, cfg.imgs)
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=cfg.batch_size,
+ num_workers=cfg.n_workers,
+ shuffle=False,
+ persistent_workers=(cfg.n_workers > 0),
+ )
+
+ tp = torch.zeros((sae.cfg.d_sae,), dtype=int, device=cfg.device)
+ fp = torch.zeros((sae.cfg.d_sae,), dtype=int, device=cfg.device)
+ fn = torch.zeros((sae.cfg.d_sae,), dtype=int, device=cfg.device)
+
+ for batch in saev.helpers.progress(dataloader):
+ pixel_labels = einops.rearrange(
+ batch["pixel_labels"],
+ "batch (w pw) (h ph) -> batch w h (pw ph)",
+ # TODO: change from hard-coded values
+ pw=16,
+ ph=16,
+ )
+ unique, counts = axis_unique(pixel_labels.numpy(), null_value=0)
+
+ # TODO: change from hard-coded values
+ # 256 is 16x16
+ idx = counts[:, :, :, 0] > (256 * cfg.label_threshold)
+ acts = batch["acts"][idx].to(cfg.device)
+ labels = unique[idx][:, 0]
+
+ _, f_x, _ = sae(acts)
+
+ pred = f_x > 0
+ true = torch.from_numpy(labels == cfg.ade20k_cls).view(-1, 1).to(cfg.device)
+
+ tp += (pred & true).sum(axis=0)
+ fp += (pred & ~true).sum(axis=0)
+ fn += (~pred & true).sum(axis=0)
+
+ f1 = (2 * tp) / (2 * tp + fp + fn)
+ indices = f1.topk(cfg.k).indices.tolist()
+
+ breakpoint()
+
+ scale_mean_flag = (
+ "--data.scale-mean" if cfg.acts.scale_mean else "--data.no-scale-mean"
+ )
+ scale_norm_flag = (
+ "--data.scale-norm" if cfg.acts.scale_norm else "--data.no-scale-norm"
+ )
+
+ print("Run this command to save best images:")
+ print()
+ print(
+ f" uv run python -m saev visuals --ckpt {cfg.ckpt} --include-latents {' '.join(indices)} --data.shard-root {cfg.data.shard_root} {scale_mean_flag} {scale_norm_flag} images:ade20k-dataset --images.root {cfg.imgs.root} --images.split {cfg.imgs.split}"
+ )
+
+
+@jaxtyped(typechecker=beartype.beartype)
+def axis_unique(
+ a: Shaped[np.ndarray, "*axes"],
+ axis: int = -1,
+ return_counts: bool = True,
+ *,
+ null_value: int = -1,
+) -> (
+ Shaped[np.ndarray, "*axes"]
+ | tuple[Shaped[np.ndarray, "*axes"], Int[np.ndarray, "*axes"]]
+):
+ """
+ Calculate unique values and their counts along any axis of a matrix.
+
+ Arguments:
+ a: Input array
+ axis: The axis along which to find unique values.
+ return_counts: If true, also return the count of each unique value
+
+ Returns:
+ unique: Array of unique values, with zeros replacing duplicates
+ counts: (optional) Count of each unique value (only if return_counts=True)
+ """
+ assert isinstance(axis, int)
+
+ # Move the target axis to the end for consistent processing
+ a_transformed = np.moveaxis(a, axis, -1)
+
+ # Sort along the last axis
+ sorted_a = np.sort(a_transformed, axis=-1)
+
+ # Find duplicates
+ duplicates = sorted_a[..., 1:] == sorted_a[..., :-1]
+
+ # Create output array
+ unique = sorted_a.copy()
+ unique[..., 1:][duplicates] = null_value
+
+ if not return_counts:
+ # Move axis back to original position
+ return np.moveaxis(unique, -1, axis)
+
+ # Calculate counts
+ shape = list(a_transformed.shape)
+ count_matrix = np.zeros(shape, dtype=int)
+
+ # Process each slice along other dimensions
+ for idx in np.ndindex(*shape[:-1]):
+ slice_unique = unique[idx]
+ idxs = np.flatnonzero(slice_unique)
+ if len(idxs) > 0:
+ # Calculate counts using diff for intermediate positions
+ counts = np.diff(idxs)
+ count_matrix[idx][idxs[:-1]] = counts
+ # Handle the last unique value
+ count_matrix[idx][idxs[-1]] = shape[-1] - idxs[-1]
+
+ # Move axes back to original positions
+ unique = np.moveaxis(unique, -1, axis)
+ count_matrix = np.moveaxis(count_matrix, -1, axis)
+
+ return unique, count_matrix
diff --git a/docs/contrib/index.html b/docs/contrib/index.html
index a61c5e7..0db731e 100644
--- a/docs/contrib/index.html
+++ b/docs/contrib/index.html
@@ -33,7 +33,7 @@
contrib.semseg
-
+This sub-module reproduces the results from Section 4.2 of our paper …
diff --git a/docs/contrib/semseg/config.html b/docs/contrib/semseg/config.html
index 1350690..91cdac9 100644
--- a/docs/contrib/semseg/config.html
+++ b/docs/contrib/semseg/config.html
@@ -147,6 +147,79 @@ Class variables
+
+class Visuals
+( sae_ckpt: str = './checkpoints/sae.pt', acts: DataLoad = <factory>, imgs: Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda', ade20k_cls: int = 29, k: int = 32)
+
+
+Visuals(sae_ckpt: str = './checkpoints/sae.pt', acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda', ade20k_cls: int = 29, k: int = 32)
+
+
+Expand source code
+
+@beartype.beartype
+@dataclasses.dataclass(frozen=True)
+class Visuals:
+ sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
+ """Path to the sae.pt file."""
+ acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad)
+ """Configuration for the saved ADE20K training ViT activations."""
+ imgs: saev.config.Ade20kDataset = dataclasses.field(
+ default_factory=lambda: saev.config.Ade20kDataset(split="training")
+ )
+ """Configuration for the ADE20K training dataset."""
+ batch_size: int = 128
+ """Batch size for calculating F1 scores."""
+ n_workers: int = 32
+ """Number of dataloader workers."""
+ label_threshold: float = 0.9
+ device: str = "cuda"
+ "Hardware for SAE inference." ""
+ ade20k_cls: int = 29
+ """ADE20K class to probe for."""
+ k: int = 32
+ """Top K features to save."""
+
+Class variables
+
+var acts : DataLoad
+
+Configuration for the saved ADE20K training ViT activations.
+
+var ade20k_cls : int
+
+ADE20K class to probe for.
+
+var batch_size : int
+
+Batch size for calculating F1 scores.
+
+var device : str
+
+Hardware for SAE inference.
+
+var imgs : Ade20kDataset
+
+Configuration for the ADE20K training dataset.
+
+var k : int
+
+
+
+var label_threshold : float
+
+
+
+var n_workers : int
+
+Number of dataloader workers.
+
+var sae_ckpt : str
+
+
+
+
+
@@ -185,6 +258,20 @@
+
+
+
diff --git a/docs/contrib/semseg/index.html b/docs/contrib/semseg/index.html
index eff3643..bdafbd5 100644
--- a/docs/contrib/semseg/index.html
+++ b/docs/contrib/semseg/index.html
@@ -5,7 +5,7 @@
contrib.semseg API documentation
-
+
@@ -27,6 +27,7 @@
Module contrib.semseg
+This sub-module reproduces the results from Section 4.2 of our paper.
Overview
As an overview:
@@ -66,6 +67,10 @@
+contrib.semseg.visuals
+
+Propose features for manual verification.
+
@@ -98,6 +103,7 @@
contrib.semseg.dashboard2
contrib.semseg.interactive
contrib.semseg.training
+contrib.semseg.visuals
diff --git a/docs/contrib/semseg/visuals.html b/docs/contrib/semseg/visuals.html
new file mode 100644
index 0000000..47f3ef6
--- /dev/null
+++ b/docs/contrib/semseg/visuals.html
@@ -0,0 +1,90 @@
+
+
+
+
+
+
+contrib.semseg.visuals API documentation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+Module contrib.semseg.visuals
+
+
+Propose features for manual verification.
+
+
+
+
+
+
+
+def axis_unique (a: jaxtyping.Shaped[ndarray, '*axes'], axis: int = -1, return_counts: bool = True, *, null_value: int = -1) ‑> jaxtyping.Shaped[ndarray, '*axes'] | tuple[jaxtyping.Shaped[ndarray, '*axes'], jaxtyping.Int[ndarray, '*axes']]
+
+
+Calculate unique values and their counts along any axis of a matrix.
+
Arguments
+
a: Input array
+axis: The axis along which to find unique values.
+return_counts: If true, also return the count of each unique value
+
Returns
+
+unique
+Array of unique values, with zeros replacing duplicates
+counts
+(optional) Count of each unique value (only if return_counts=True)
+
+
+
+def main (cfg: Visuals )
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/docs/index.html b/docs/index.html
index c72fc8f..6b33c65 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -35,8 +35,8 @@ Package Docs
Package to train SAEs for vision models.
-
probing
-
Package for probing for individual features in trained SAEs.
+
contrib
+
Individual sub-packages not related to the core package.
faithfulness
diff --git a/docs/llms.txt b/docs/llms.txt
index 59d52f8..2daaa4b 100644
--- a/docs/llms.txt
+++ b/docs/llms.txt
@@ -1401,6 +1401,8 @@ Sub-modules
Module contrib.semseg
=====================
+This sub-module reproduces the results from Section 4.2 of our paper.
+
# Overview
As an overview:
@@ -1429,6 +1431,7 @@ Sub-modules
* contrib.semseg.dashboard2
* contrib.semseg.interactive
* contrib.semseg.training
+* contrib.semseg.visuals
Module contrib.semseg.config
============================
@@ -1486,6 +1489,38 @@ Classes
`weight_decay: float`
: Weight decay for AdamW.
+`Visuals(sae_ckpt: str = './checkpoints/sae.pt', acts: saev.config.DataLoad =
, imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda', ade20k_cls: int = 29, k: int = 32)`
+: Visuals(sae_ckpt: str = './checkpoints/sae.pt', acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda', ade20k_cls: int = 29, k: int = 32)
+
+ ### Class variables
+
+ `acts: saev.config.DataLoad`
+ : Configuration for the saved ADE20K training ViT activations.
+
+ `ade20k_cls: int`
+ : ADE20K class to probe for.
+
+ `batch_size: int`
+ : Batch size for calculating F1 scores.
+
+ `device: str`
+ : Hardware for SAE inference.
+
+ `imgs: saev.config.Ade20kDataset`
+ : Configuration for the ADE20K training dataset.
+
+ `k: int`
+ : Top K features to save.
+
+ `label_threshold: float`
+ :
+
+ `n_workers: int`
+ : Number of dataloader workers.
+
+ `sae_ckpt: str`
+ : Path to the sae.pt file.
+
Module contrib.semseg.dashboard
===============================
@@ -1579,4 +1614,26 @@ Classes
### Instance variables
`d_vit: int`
- :
\ No newline at end of file
+ :
+
+Module contrib.semseg.visuals
+=============================
+Propose features for manual verification.
+
+Functions
+---------
+
+`axis_unique(a: jaxtyping.Shaped[ndarray, '*axes'], axis: int = -1, return_counts: bool = True, *, null_value: int = -1) ‑> jaxtyping.Shaped[ndarray, '*axes'] | tuple[jaxtyping.Shaped[ndarray, '*axes'], jaxtyping.Int[ndarray, '*axes']]`
+: Calculate unique values and their counts along any axis of a matrix.
+
+ Arguments:
+ a: Input array
+ axis: The axis along which to find unique values.
+ return_counts: If true, also return the count of each unique value
+
+ Returns:
+ unique: Array of unique values, with zeros replacing duplicates
+ counts: (optional) Count of each unique value (only if return_counts=True)
+
+`main(cfg: contrib.semseg.config.Visuals)`
+:
\ No newline at end of file