diff --git a/contrib/semseg/__main__.py b/contrib/semseg/__main__.py
index e6cd1da..f2ed6e1 100644
--- a/contrib/semseg/__main__.py
+++ b/contrib/semseg/__main__.py
@@ -43,8 +43,11 @@ def train(
 
 
 @beartype.beartype
-def visuals():
-    print("Not implemented.")
+def visuals(cfg: typing.Annotated[config.Visuals, tyro.conf.arg(name="")]):
+    """Propose SAE features for manual verification (see `visuals.main`)."""
+    from . import visuals
+
+    visuals.main(cfg)
 
 
 if __name__ == "__main__":
diff --git a/contrib/semseg/config.py b/contrib/semseg/config.py
index cd5c59b..19a65ed 100644
--- a/contrib/semseg/config.py
+++ b/contrib/semseg/config.py
@@ -41,6 +41,31 @@ class Train:
     log_to: str = os.path.join(".", "logs")
 
 
+@beartype.beartype
+@dataclasses.dataclass(frozen=True)
+class Visuals:
+    sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
+    """Path to the sae.pt file."""
+    acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad)
+    """Configuration for the saved ADE20K training ViT activations."""
+    imgs: saev.config.Ade20kDataset = dataclasses.field(
+        default_factory=lambda: saev.config.Ade20kDataset(split="training")
+    )
+    """Configuration for the ADE20K training dataset."""
+    batch_size: int = 128
+    """Batch size for calculating F1 scores."""
+    n_workers: int = 32
+    """Number of dataloader workers."""
+    label_threshold: float = 0.9
+    """Minimum fraction of a patch's pixels that must share one label for the patch to be used."""
+    device: str = "cuda"
+    """Hardware for SAE inference."""
+    ade20k_cls: int = 29
+    """ADE20K class to probe for."""
+    k: int = 32
+    """Top K features to save."""
+
+
 @beartype.beartype
 def grid(cfg: Train, sweep_dct: dict[str, object]) -> tuple[list[Train], list[str]]:
     cfgs, errs = [], []
diff --git a/contrib/semseg/reproduce.md b/contrib/semseg/reproduce.md
index cbe235a..bcebfd9 100644
--- a/contrib/semseg/reproduce.md
+++ b/contrib/semseg/reproduce.md
@@ -1,3 +1,5 @@
+This sub-module reproduces the results from Section 4.2 of our paper.
+
 # Overview
 
 As an overview:
diff --git a/contrib/semseg/visuals.py b/contrib/semseg/visuals.py
new file mode 100644
index 0000000..149eb8e
--- /dev/null
+++ b/contrib/semseg/visuals.py
@@ -0,0 +1,150 @@
+"""
+Propose features for manual verification.
+"""
+
+from . import config, training
+import einops
+import beartype
+import torch
+import numpy as np
+from jaxtyping import jaxtyped, Int, Shaped
+
+import saev.nn
+import saev.helpers
+
+
+@beartype.beartype
+@torch.no_grad
+def main(cfg: config.Visuals):
+    """
+    Rank SAE latents by F1 score for one ADE20K class and print a follow-up command.
+
+    Arguments:
+        cfg: Checkpoint, data, and probing options.
+    """
+    sae = saev.nn.load(cfg.sae_ckpt)
+    sae = sae.to(cfg.device)
+
+    dataset = training.Dataset(cfg.acts, cfg.imgs)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=cfg.batch_size,
+        num_workers=cfg.n_workers,
+        shuffle=False,
+        persistent_workers=(cfg.n_workers > 0),
+    )
+
+    tp = torch.zeros((sae.cfg.d_sae,), dtype=int, device=cfg.device)
+    fp = torch.zeros((sae.cfg.d_sae,), dtype=int, device=cfg.device)
+    fn = torch.zeros((sae.cfg.d_sae,), dtype=int, device=cfg.device)
+
+    for batch in saev.helpers.progress(dataloader):
+        pixel_labels = einops.rearrange(
+            batch["pixel_labels"],
+            "batch (w pw) (h ph) -> batch w h (pw ph)",
+            # TODO: change from hard-coded values
+            pw=16,
+            ph=16,
+        )
+        unique, counts = axis_unique(pixel_labels.numpy(), null_value=0)
+
+        # TODO: change from hard-coded values
+        # 256 is 16x16
+        # NOTE(review): index 0 is the smallest label in each patch, so patches
+        # dominated by any other label are skipped -- confirm this is intended.
+        idx = counts[:, :, :, 0] > (256 * cfg.label_threshold)
+        acts = batch["acts"][idx].to(cfg.device)
+        labels = unique[idx][:, 0]
+
+        _, f_x, _ = sae(acts)
+
+        pred = f_x > 0
+        true = torch.from_numpy(labels == cfg.ade20k_cls).view(-1, 1).to(cfg.device)
+
+        tp += (pred & true).sum(axis=0)
+        fp += (pred & ~true).sum(axis=0)
+        fn += (~pred & true).sum(axis=0)
+
+    # A zero denominator implies tp == fp == fn == 0, so clamping gives F1 = 0 instead of NaN.
+    f1 = (2 * tp) / (2 * tp + fp + fn).clamp(min=1)
+    indices = f1.topk(min(cfg.k, f1.numel())).indices.tolist()
+
+    scale_mean_flag = (
+        "--data.scale-mean" if cfg.acts.scale_mean else "--data.no-scale-mean"
+    )
+    scale_norm_flag = (
+        "--data.scale-norm" if cfg.acts.scale_norm else "--data.no-scale-norm"
+    )
+    latents = " ".join(str(i) for i in indices)
+
+    print("Run this command to save best images:")
+    print()
+    print(
+        f"  uv run python -m saev visuals --ckpt {cfg.sae_ckpt} --include-latents {latents} --data.shard-root {cfg.acts.shard_root} {scale_mean_flag} {scale_norm_flag} images:ade20k-dataset --images.root {cfg.imgs.root} --images.split {cfg.imgs.split}"
+    )
+
+
+@jaxtyped(typechecker=beartype.beartype)
+def axis_unique(
+    a: Shaped[np.ndarray, "*axes"],
+    axis: int = -1,
+    return_counts: bool = True,
+    *,
+    null_value: int = -1,
+) -> (
+    Shaped[np.ndarray, "*axes"]
+    | tuple[Shaped[np.ndarray, "*axes"], Int[np.ndarray, "*axes"]]
+):
+    """
+    Calculate unique values and their counts along any axis of a matrix.
+
+    Arguments:
+        a: Input array
+        axis: The axis along which to find unique values.
+        return_counts: If true, also return the count of each unique value
+        null_value: Placeholder written over duplicates. Entries of `a` equal to it are indistinguishable from duplicates.
+
+    Returns:
+        unique: Sorted values along `axis`, with `null_value` replacing duplicates
+        counts: (optional) Count of each unique value (only if return_counts=True)
+    """
+    assert isinstance(axis, int)
+
+    # Move the target axis to the end for consistent processing
+    a_transformed = np.moveaxis(a, axis, -1)
+
+    # Sort along the last axis
+    sorted_a = np.sort(a_transformed, axis=-1)
+
+    # Find duplicates
+    duplicates = sorted_a[..., 1:] == sorted_a[..., :-1]
+
+    # Create output array
+    unique = sorted_a.copy()
+    unique[..., 1:][duplicates] = null_value
+
+    if not return_counts:
+        # Move axis back to original position
+        return np.moveaxis(unique, -1, axis)
+
+    # Calculate counts
+    shape = list(a_transformed.shape)
+    count_matrix = np.zeros(shape, dtype=int)
+
+    # Process each slice along other dimensions
+    for idx in np.ndindex(*shape[:-1]):
+        slice_unique = unique[idx]
+        # Compare against null_value (not zero) so any placeholder works.
+        idxs = np.flatnonzero(slice_unique != null_value)
+        if len(idxs) > 0:
+            # Calculate counts using diff for intermediate positions
+            counts = np.diff(idxs)
+            count_matrix[idx][idxs[:-1]] = counts
+            # Handle the last unique value
+            count_matrix[idx][idxs[-1]] = shape[-1] - idxs[-1]
+
+    # Move axes back to original positions
+    unique = np.moveaxis(unique, -1, axis)
+    count_matrix = np.moveaxis(count_matrix, -1, axis)
+
+    return unique, count_matrix
diff --git a/docs/contrib/index.html b/docs/contrib/index.html
index a61c5e7..0db731e 100644
--- a/docs/contrib/index.html
+++ b/docs/contrib/index.html
@@ -33,7 +33,7 @@

Sub-modules

contrib.semseg
-

Overview …

+

This sub-module reproduces the results from Section 4.2 of our paper …

diff --git a/docs/contrib/semseg/config.html b/docs/contrib/semseg/config.html index 1350690..91cdac9 100644 --- a/docs/contrib/semseg/config.html +++ b/docs/contrib/semseg/config.html @@ -147,6 +147,79 @@

Class variables

+
+class Visuals +(sae_ckpt: str = './checkpoints/sae.pt', acts: DataLoad = <factory>, imgs: Ade20kDataset = <factory>, batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda', ade20k_cls: int = 29, k: int = 32) +
+
+

Visuals(sae_ckpt: str = './checkpoints/sae.pt', acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda', ade20k_cls: int = 29, k: int = 32)

+
+ +Expand source code + +
@beartype.beartype
+@dataclasses.dataclass(frozen=True)
+class Visuals:
+    sae_ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
+    """Path to the sae.pt file."""
+    acts: saev.config.DataLoad = dataclasses.field(default_factory=saev.config.DataLoad)
+    """Configuration for the saved ADE20K training ViT activations."""
+    imgs: saev.config.Ade20kDataset = dataclasses.field(
+        default_factory=lambda: saev.config.Ade20kDataset(split="training")
+    )
+    """Configuration for the ADE20K training dataset."""
+    batch_size: int = 128
+    """Batch size for calculating F1 scores."""
+    n_workers: int = 32
+    """Number of dataloader workers."""
+    label_threshold: float = 0.9
+    device: str = "cuda"
+    "Hardware for SAE inference." ""
+    ade20k_cls: int = 29
+    """ADE20K class to probe for."""
+    k: int = 32
+    """Top K features to save."""
+
+

Class variables

+
+
var actsDataLoad
+
+

Configuration for the saved ADE20K training ViT activations.

+
+
var ade20k_cls : int
+
+

ADE20K class to probe for.

+
+
var batch_size : int
+
+

Batch size for calculating F1 scores.

+
+
var device : str
+
+

Hardware for SAE inference.

+
+
var imgsAde20kDataset
+
+

Configuration for the ADE20K training dataset.

+
+
var k : int
+
+

Top K features to save.

+
+
var label_threshold : float
+
+
+
+
var n_workers : int
+
+

Number of dataloader workers.

+
+
var sae_ckpt : str
+
+

Path to the sae.pt file.

+
+
+
@@ -185,6 +258,20 @@

weight_decay +
  • +

    Visuals

    + +
  • diff --git a/docs/contrib/semseg/index.html b/docs/contrib/semseg/index.html index eff3643..bdafbd5 100644 --- a/docs/contrib/semseg/index.html +++ b/docs/contrib/semseg/index.html @@ -5,7 +5,7 @@ contrib.semseg API documentation - + @@ -27,6 +27,7 @@

    Module contrib.semseg

    +

    This sub-module reproduces the results from Section 4.2 of our paper.

    Overview

    As an overview:

      @@ -66,6 +67,10 @@

      Sub-modules

      +
      contrib.semseg.visuals
      +
      +

      Propose features for manual verification.

      +
    @@ -98,6 +103,7 @@

    Sub-modules

  • contrib.semseg.dashboard2
  • contrib.semseg.interactive
  • contrib.semseg.training
  • +
  • contrib.semseg.visuals
  • diff --git a/docs/contrib/semseg/visuals.html b/docs/contrib/semseg/visuals.html new file mode 100644 index 0000000..47f3ef6 --- /dev/null +++ b/docs/contrib/semseg/visuals.html @@ -0,0 +1,90 @@ + + + + + + +contrib.semseg.visuals API documentation + + + + + + + + + + + + + +
    +
    +
    +

    Module contrib.semseg.visuals

    +
    +
    +

    Propose features for manual verification.

    +
    +
    +
    +
    +
    +
    +

    Functions

    +
    +
    +def axis_unique(a: jaxtyping.Shaped[ndarray, '*axes'], axis: int = -1, return_counts: bool = True, *, null_value: int = -1) ‑> jaxtyping.Shaped[ndarray, '*axes'] | tuple[jaxtyping.Shaped[ndarray, '*axes'], jaxtyping.Int[ndarray, '*axes']] +
    +
    +

    Calculate unique values and their counts along any axis of a matrix.

    +

    Arguments

    +

    a: Input array +axis: The axis along which to find unique values. +return_counts: If true, also return the count of each unique value

    +

    Returns

    +
    +
    unique
    +
    Array of unique values, with null_value replacing duplicates
    +
    counts
    +
    (optional) Count of each unique value (only if return_counts=True)
    +
    +
    +
    +def main(cfg: Visuals) +
    +
    +
    +
    +
    +
    +
    +
    +
    + +
    + + + diff --git a/docs/index.html b/docs/index.html index c72fc8f..6b33c65 100644 --- a/docs/index.html +++ b/docs/index.html @@ -35,8 +35,8 @@

    Package Docs

    Package to train SAEs for vision models.

    -
    probing
    -

    Package for probing for individual features in trained SAEs.

    +
    contrib
    +

    Individual sub-packages not related to the core package.

    faithfulness
    diff --git a/docs/llms.txt b/docs/llms.txt index 59d52f8..2daaa4b 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -1401,6 +1401,8 @@ Sub-modules Module contrib.semseg ===================== +This sub-module reproduces the results from Section 4.2 of our paper. + # Overview As an overview: @@ -1429,6 +1431,7 @@ Sub-modules * contrib.semseg.dashboard2 * contrib.semseg.interactive * contrib.semseg.training +* contrib.semseg.visuals Module contrib.semseg.config ============================ @@ -1486,6 +1489,38 @@ Classes `weight_decay: float` : Weight decay for AdamW. +`Visuals(sae_ckpt: str = './checkpoints/sae.pt', acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda', ade20k_cls: int = 29, k: int = 32)` +: Visuals(sae_ckpt: str = './checkpoints/sae.pt', acts: saev.config.DataLoad = , imgs: saev.config.Ade20kDataset = , batch_size: int = 128, n_workers: int = 32, label_threshold: float = 0.9, device: str = 'cuda', ade20k_cls: int = 29, k: int = 32) + + ### Class variables + + `acts: saev.config.DataLoad` + : Configuration for the saved ADE20K training ViT activations. + + `ade20k_cls: int` + : ADE20K class to probe for. + + `batch_size: int` + : Batch size for calculating F1 scores. + + `device: str` + : Hardware for SAE inference. + + `imgs: saev.config.Ade20kDataset` + : Configuration for the ADE20K training dataset. + + `k: int` + : Top K features to save. + + `label_threshold: float` + : + + `n_workers: int` + : Number of dataloader workers. + + `sae_ckpt: str` + : Path to the sae.pt file. + Module contrib.semseg.dashboard =============================== @@ -1579,4 +1614,26 @@ Classes ### Instance variables `d_vit: int` - : \ No newline at end of file + : + +Module contrib.semseg.visuals +============================= +Propose features for manual verification. 
+ +Functions +--------- + +`axis_unique(a: jaxtyping.Shaped[ndarray, '*axes'], axis: int = -1, return_counts: bool = True, *, null_value: int = -1) ‑> jaxtyping.Shaped[ndarray, '*axes'] | tuple[jaxtyping.Shaped[ndarray, '*axes'], jaxtyping.Int[ndarray, '*axes']]` +: Calculate unique values and their counts along any axis of a matrix. + + Arguments: + a: Input array + axis: The axis along which to find unique values. + return_counts: If true, also return the count of each unique value + + Returns: + unique: Array of unique values, with zeros replacing duplicates + counts: (optional) Count of each unique value (only if return_counts=True) + +`main(cfg: contrib.semseg.config.Visuals)` +: \ No newline at end of file