diff --git a/contrib/faithfulness/README.md b/contrib/faithfulness/README.md new file mode 100644 index 0000000..cf41ca5 --- /dev/null +++ b/contrib/faithfulness/README.md @@ -0,0 +1,22 @@ +# Faithfulness + +This module demonstrates that SAE features are faithful and that the underlying vision model does in fact depend on the features to make its predictions. + +It shows this through an interactive dashboard and through larger-scale quantitative experiments. + +## Dashboard + +First, record activations for the ADE20K dataset. + +```sh +uv run python -m saev activations \ + --model-family clip \ + --model-ckpt ViT-B-16/openai \ + --d-vit 768 \ + --n-patches-per-img 196 \ + --layers -2 \ + --dump-to /local/scratch/$USER/cache/saev \ + --n-patches-per-shard 2_4000_000 \ + data:ade20k-dataset \ + --data.root /research/nfs_su_809/workspace/stevens.994/datasets/ade20k/images +``` diff --git a/faithfulness/dashboard.py b/faithfulness/dashboard.py index 18f207e..c29dcc8 100644 --- a/faithfulness/dashboard.py +++ b/faithfulness/dashboard.py @@ -36,6 +36,7 @@ def __(): import saev.activations import saev.config import saev.helpers + import saev.nn return ( Image, ImageDraw, @@ -87,8 +88,8 @@ def __(saev): shard_root="/local/scratch/stevens.994/cache/saev/e20bbda1b6b011896dc6f49a698597a7ec000390d73cd7197b0fb243a1e13273/", patches="patches", layer=-2, - scale_norm=True, - scale_mean=True, + scale_norm=False, + scale_mean=False, ) ) @@ -96,8 +97,6 @@ def __(saev): saev.config.Ade20kDataset( root="/research/nfs_su_809/workspace/stevens.994/datasets/ade20k" ), - # transform=make_img_transform(), - # seg_transform=make_seg_transform(), ) return act_dataset, img_dataset @@ -162,37 +161,279 @@ def __(): @app.cell def __(act_dataset, cls_lookup, df, np, obj_classes, pl, torch): - activations, labels, i_ims, i_ps = [], [], [], [] + activations, labels, i_ims, i_ps, i_acts = [], [], [], [], [] for obj_cls in obj_classes: for i_act, i_im, i_p, obj_cls in ( - df.filter(pl.col("obj_cls") == obj_cls).sample(50).iter_rows() + df.filter(pl.col("obj_cls") == obj_cls).sample(50, seed=42).iter_rows() ): activations.append(act_dataset[i_act].vit_acts) labels.append(cls_lookup[obj_cls]) i_ims.append(i_im) i_ps.append(i_p) + i_acts.append(i_act) activations = torch.stack(activations).numpy() labels = np.array(labels) i_im = np.array(i_ims) i_p = np.array(i_ps) + i_act = np.array(i_acts) activations.shape - return activations, i_act, i_im, i_ims, i_p, i_ps, labels, obj_cls + return ( + activations, + i_act, + i_acts, + i_im, + i_ims, + i_p, + i_ps, + labels, + obj_cls, + ) @app.cell -def __(activations, sklearn): +def __(activations, sklearn, without_outliers): pca = sklearn.decomposition.IncrementalPCA(n_components=2) - pca.fit(activations) - x_r = pca.transform(activations) + pca.fit(activations[without_outliers]) + x_r = pca.transform(activations[without_outliers]) return pca, x_r @app.cell -def __(alt, directions, i_im, i_p, labels, mo, np, pca, pl, sliders, x_r): +def __(np, x_r): + print(np.nonzero(x_r[:, 0] < 10)[0].tolist()) + return + + +@app.cell +def __(np): + without_outliers = np.arange(200).tolist() + without_outliers = [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, +
67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 134, + 135, + 136, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 150, + 151, + 152, + 153, + 154, + 155, + 156, + 157, + 158, + 159, + 160, + 161, + 162, + 163, + 164, + 165, + 166, + 167, + 168, + 169, + 170, + 171, + 172, + 173, + 174, + 175, + 176, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 184, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 194, + 195, + 196, + 197, + 198, + 199, + ] + len(without_outliers) + return (without_outliers,) + + +@app.cell +def __( + alt, + directions, + i_im, + i_p, + labels, + mo, + np, + pca, + pl, + sliders, + without_outliers, + x_r, +): x_shift, y_shift = ( - pca.transform((np.array(sliders.value) @ directions).reshape(1, -1)) + ( + pca.transform( + (np.array(sliders.value) @ directions.detach().numpy()).reshape(1, -1) + ) + - pca.transform(np.zeros((1, 768))) + ) .reshape(-1) .astype(np.float32) ) @@ -202,19 +443,21 @@ def __(alt, directions, i_im, i_p, labels, mo, np, pca, pl, sliders, x_r): pl.concat( ( pl.from_numpy(x_r, ("x", "y")), - pl.from_numpy(i_im, ("i_im",)), - pl.from_numpy(i_p, ("i_p",)), - pl.from_numpy(labels, ("label",)), + pl.from_numpy(i_im[without_outliers], ("i_im",)), + pl.from_numpy(i_p[without_outliers], ("i_p",)), + pl.from_numpy(labels[without_outliers], ("label",)), + pl.from_numpy(np.array(without_outliers), ("example_index",)), ), how="horizontal", ).vstack( pl.DataFrame( { - "x": 0 + x_shift, - "y": 0 + y_shift, - "i_im": 0, - "i_p": 0, - "label": "manipulated", + "x": x_r[141, 0] + x_shift, + "y": x_r[141, 1] + y_shift, + "i_im": i_im[141].item(), + "i_p": i_p[141].item(), + "label": f"{labels[141].item()} (manipulated)", + "example_index": 141, } ) ) @@ -223,7 +466,7 @@ def __(alt, directions, i_im, i_p, labels, mo, np, pca, pl, sliders, x_r): .encode( x=alt.X("x"), y=alt.Y("y"), - tooltip=["i_im"], + tooltip=["example_index", "i_im"], color="label:N", shape="label:N", ) @@ -253,20 +496,32 @@ def __(chart, highlight_patches, img_dataset, mo): @app.cell -def __(mo, np): +def __(): + features = {"rug1": 8541, "rug2": 5818, "window1": 8177} + return (features,) + + +@app.cell +def __(features, mo, sae): # Instead of random unit-norm directions, we should be using the sparse autoencoder to choose the directions. # Specifically, we can get f_x for the manipulated patch, then pick the dimensions that have maximal value. # We can pick out the columns of W_dec and move the patch in those directions. 
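+    # For example, a sketch using `sae` and `f_x` from the cells further down: `_, top = f_x.topk(k=2)` followed by `directions = sae.W_dec[top.squeeze()]` would take the decoder rows of the two most-active features instead of the hand-picked ids in `features`.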
- directions = np.random.default_rng(seed=3).random((2, 768)) - directions /= np.linalg.norm(directions, axis=1, keepdims=True) + + # _, f = f_x.topk(10) + direction_names = list(features.keys()) + + directions = sae.W_dec[[features[name] for name in direction_names]] + # directions = np.random.default_rng(seed=3).random((2, 768)) + # directions /= np.linalg.norm(directions, axis=1, keepdims=True) sliders = mo.ui.array( [ - mo.ui.slider(-500, 500, step=10.0, label=f"Direction {i+1}", value=0) - for i in range(2) + mo.ui.slider(-50, 50, step=3.0, label=f"Direction '{name}' ({features[name]})", value=0) + for name in direction_names ] ) - return directions, sliders + # " ".join([str(i) for i in f.squeeze().tolist()]) + return direction_names, directions, sliders @app.cell @@ -411,5 +666,36 @@ def make_seg_transform(): ) +@app.cell +def __(saev): + sae = saev.nn.load("/home/stevens.994/projects/saev-live/checkpoints/ercgckr1/sae.pt") + print(sae) + return (sae,) + + +@app.cell +def __(activations, sae, torch): + with torch.no_grad(): + x_hat, f_x, _ = sae(torch.from_numpy(activations[101:102])) + return f_x, x_hat + + +@app.cell +def __(activations, torch, x_hat): + (x_hat - torch.from_numpy(activations[101:102])).pow(2).mean() + return + + +@app.cell +def __(f_x): + f_x.topk(5) + return + + +@app.cell +def __(): + return + + if __name__ == "__main__": app.run() diff --git a/pyproject.toml b/pyproject.toml index d5f3ea7..2c6ad70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "open-clip-torch>=2.28.0", "pillow>=11.0.0", "polars>=1.12.0", + "scikit-learn>=1.5.2", "submitit>=1.5.2", "torch>=2.5.0", "tqdm>=4.66.5", diff --git a/saev/activations.py b/saev/activations.py index 3753244..4e1b69f 100644 --- a/saev/activations.py +++ b/saev/activations.py @@ -9,22 +9,22 @@ 2. Multiple [n_imgs_per_shard, n_layers, (n_patches + 1), d_vit] tensors. This is a set of sharded activations. """ +import csv import dataclasses import hashlib import json import logging import math import os -import shutil import typing import beartype import numpy as np import torch import torchvision.datasets -import wids from jaxtyping import Float, Int, jaxtyped from torch import Tensor +from PIL import Image from . 
import config, helpers @@ -32,6 +32,11 @@ logging.basicConfig(level=logging.INFO, format=log_format) +####################### +# VISION TRANSFORMERS # +####################### + + @jaxtyped(typechecker=beartype.beartype) class VitRecorder(torch.nn.Module): cfg: config.Activations @@ -47,7 +52,9 @@ def __init__( self.patches = patches self._storage = None self._i = 0 - self.logger = logging.getLogger(f"recorder({cfg.model_org}:{cfg.model_ckpt})") + self.logger = logging.getLogger( + f"recorder({cfg.model_family}:{cfg.model_ckpt})" + ) def register(self, modules: list[torch.nn.Module]): for i in self.cfg.layers: @@ -97,7 +104,7 @@ class Clip(torch.nn.Module): def __init__(self, cfg: config.Activations): super().__init__() - assert cfg.model_org == "clip" + assert cfg.model_family == "clip" import open_clip @@ -132,7 +139,7 @@ def forward(self, batch: Float[Tensor, "batch 3 width height"]): class Siglip(torch.nn.Module): def __init__(self, cfg: config.Activations): super().__init__() - assert cfg.model_org == "siglip" + assert cfg.model_family == "siglip" import open_clip @@ -163,47 +170,12 @@ def forward(self, batch: Float[Tensor, "batch 3 width height"]): return result, self.recorder.activations -@jaxtyped(typechecker=beartype.beartype) -class TimmVit(torch.nn.Module): - def __init__(self, cfg: config.Activations): - super().__init__() - assert cfg.model_org == "timm" - import timm - - err_msg = "You are trying to load a non-ViT checkpoint; the `img_encode()` method assumes `model.forward_features()` will return features with shape (batch, n_patches, dim) which is not true for non-ViT checkpoints." - assert "vit" in cfg.model_ckpt, err_msg - self.model = timm.create_model(cfg.model_ckpt, pretrained=True) - - data_cfg = timm.data.resolve_data_config(self.model.pretrained_cfg) - self._img_transform = timm.data.create_transform(**data_cfg, is_training=False) - - self.recorder = VitRecorder(cfg).register(self.model.blocks) - - def make_img_transform(self): - return self._img_transform - - def forward(self, batch: Float[Tensor, "batch 3 width height"]): - self.recorder.reset() - - patches = self.model.forward_features(batch) - # Use [CLS] token if it exists for img representation, otherwise do a maxpool - if self.model.num_prefix_tokens > 0: - img = patches[:, 0, ...] - else: - img = patches.max(axis=1).values - - # Return only the [CLS] token and the patches. - patches = patches[:, self.model.num_prefix_tokens :, ...] - - return torch.cat((img[:, None, :], patches), axis=1), self.recorder.activations - - @jaxtyped(typechecker=beartype.beartype) class DinoV2(torch.nn.Module): def __init__(self, cfg: config.Activations): super().__init__() - assert cfg.model_org == "dinov2" + assert cfg.model_family == "dinov2" self.model = torch.hub.load("facebookresearch/dinov2", cfg.model_ckpt) @@ -219,6 +191,7 @@ def make_img_transform(self): from torchvision.transforms import v2 return v2.Compose([ + # TODO: I bet this should be 256, 256, which is causing localization issues in non-square images. 
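+            # Note: `v2.Resize(size=256)` matches the *shorter* edge to 256 and keeps the aspect ratio; `v2.Resize(size=(256, 256))` would force the square resize that the TODO above suggests.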
v2.Resize(size=256), v2.CenterCrop(size=(224, 224)), v2.ToImage(), @@ -239,16 +212,19 @@ def forward(self, batch: Float[Tensor, "batch 3 width height"]): @beartype.beartype def make_vit(cfg: config.Activations): - if cfg.model_org == "timm": - return TimmVit(cfg) - elif cfg.model_org == "clip": + if cfg.model_family == "clip": return Clip(cfg) - elif cfg.model_org == "siglip": + elif cfg.model_family == "siglip": return Siglip(cfg) - elif cfg.model_org == "dinov2": + elif cfg.model_family == "dinov2": return DinoV2(cfg) else: - typing.assert_never(cfg.model_org) + typing.assert_never(cfg.model_family) + + +############### +# ACTIVATIONS # +############### @jaxtyped(typechecker=beartype.beartype) @@ -487,6 +463,11 @@ def __len__(self) -> int: typing.assert_never((self.cfg.patches, self.cfg.layer)) +########## +# IMAGES # +########## + + @beartype.beartype def setup(cfg: config.Activations): """ @@ -494,12 +475,10 @@ def setup(cfg: config.Activations): """ if isinstance(cfg.data, config.ImagenetDataset): setup_imagenet(cfg) - elif isinstance(cfg.data, config.TreeOfLifeDataset): - setup_tol(cfg) - elif isinstance(cfg.data, config.LaionDataset): - setup_laion(cfg) elif isinstance(cfg.data, config.ImageFolderDataset): setup_imagefolder(cfg) + elif isinstance(cfg.data, config.Ade20kDataset): + setup_ade20k(cfg) else: typing.assert_never(cfg.data) @@ -510,116 +489,19 @@ def setup_imagenet(cfg: config.Activations): @beartype.beartype -def setup_tol(cfg: config.Activations): - assert isinstance(cfg.data, config.TreeOfLifeDataset) +def setup_imagefolder(cfg: config.Activations): + assert isinstance(cfg.data, config.ImageFolderDataset) + breakpoint() @beartype.beartype -def setup_laion(cfg: config.Activations): - """ - Do setup for LAION dataloader. - """ - assert isinstance(cfg.data, config.LaionDataset) - - import datasets - import img2dataset - import submitit - - logger = logging.getLogger("laion") - - # 1. Download cfg.data.n_imgs data urls. +def setup_ade20k(cfg: config.Activations): + assert isinstance(cfg.data, config.Ade20kDataset) - # Check if URL list exists. - n_urls = 0 - if os.path.isfile(cfg.data.url_list_filepath): + # url = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip" + # breakpoint() - def blocks(files, size=65536): - while True: - b = files.read(size) - if not b: - break - yield b - - with open(cfg.data.url_list_filepath, "r") as fd: - n_urls = sum(bl.count("\n") for bl in blocks(fd)) - - # We use -1 just in case there's something wrong with our n_urls count. - dumped_urls = n_urls >= cfg.data.n_imgs - 1 - - # If we don't have all the image urls written, need to dump to a file. - if not dumped_urls: - logger.info("Dumping URLs to '%s'.", cfg.data.url_list_filepath) - - if os.path.isfile(cfg.data.url_list_filepath): - logger.warning( - "Overwriting existing list of %d URLs because we want %d URLs.", - n_urls, - cfg.data.n_imgs, - ) - - dataset = ( - datasets.load_dataset(cfg.data.name, streaming=True, split="train") - .shuffle(cfg.seed) - .filter( - lambda example: example["status"] == "success" - and example["height"] >= 256 - and example["width"] >= 256 - ) - .take(cfg.data.n_imgs) - ) - - with open(cfg.data.url_list_filepath, "w") as fd: - for example in helpers.progress( - dataset, every=5_000, desc="Writing URLs", total=cfg.data.n_imgs - ): - fd.write(f'{{"url": "{example["url"]}", "key": "{example["key"]}"}}\n') - - # 2. Download the images to a webdatset format using img2dataset - # TODO: check whether images are downloaded. 
Read all the _stats.json files and see if we have all 10M. - imgs_downloaded = False - - if not imgs_downloaded: - - def download(n_processes: int, n_threads: int): - assert isinstance(cfg.data, config.LaionDataset) - - if os.path.exists(cfg.data.tar_dir): - shutil.rmtree(cfg.data.tar_dir) - - img2dataset.download( - url_list=cfg.data.url_list_filepath, - input_format="jsonl", - image_size=256, - output_folder=cfg.data.tar_dir, - processes_count=n_processes, - thread_count=n_threads, - resize_mode="keep_ratio", - encode_quality=100, - encode_format="webp", - output_format="webdataset", - oom_shard_count=6, - ignore_ssl=not cfg.ssl, - ) - - if cfg.slurm: - executor = submitit.SlurmExecutor(folder=cfg.log_to) - executor.update_parameters( - time=12 * 60, - partition="cpuonly", - cpus_per_task=64, - stderr_to_stdout=True, - account=cfg.slurm_acct, - ) - job = executor.submit(download, 64, 256) - job.result() - else: - download(8, 32) - - -@beartype.beartype -def setup_imagefolder(cfg: config.Activations): - assert isinstance(cfg.data, config.ImageFolderDataset) - breakpoint() + # 1. Check @beartype.beartype @@ -635,9 +517,11 @@ def get_dataset(cfg: config.DatasetConfig, *, transform): A dataset that has dictionaries with `'image'`, `'index'`, `'target'`, and `'label'` keys containing examples. """ if isinstance(cfg, config.ImagenetDataset): - return TransformedImagenet(cfg, transform) + return TransformedImagenet(cfg, transform=transform) + elif isinstance(cfg, config.Ade20kDataset): + return TransformedAde20k(cfg, transform=transform) else: - typing.assert_never(cfg.data) + typing.assert_never(cfg) @beartype.beartype @@ -652,12 +536,11 @@ def get_dataloader(cfg: config.Activations, preprocess): Returns: A PyTorch Dataloader that yields dictionaries with `'image'` keys containing image batches. """ - if isinstance(cfg.data, (config.ImagenetDataset, config.ImageFolderDataset)): + if isinstance( + cfg.data, + (config.ImagenetDataset, config.ImageFolderDataset, config.Ade20kDataset), + ): dataloader = get_default_dataloader(cfg, preprocess) - elif isinstance(cfg.data, config.TreeOfLifeDataset): - dataloader = get_tol_dataloader(cfg, preprocess) - elif isinstance(cfg.data, config.LaionDataset): - dataloader = get_laion_dataloader(cfg, preprocess) else: typing.assert_never(cfg.data) @@ -678,7 +561,7 @@ def get_default_dataloader( Returns: A PyTorch Dataloader that yields dictionaries with `'image'` keys containing image batches, `'index'` keys containing original dataset indices and `'label'` keys containing label batches. """ - dataset = get_dataset(cfg, transform) + dataset = get_dataset(cfg.data, transform=transform) dataloader = torch.utils.data.DataLoader( dataset=dataset, @@ -693,61 +576,8 @@ def get_default_dataloader( @beartype.beartype -def get_laion_dataloader( - cfg: config.Activations, preprocess -) -> torch.utils.data.DataLoader: - """ - Get a dataloader for a subset of the LAION datasets. - - This requires several steps: - - 1. Download list of image URLs - 2. Use img2dataset to download these images to webdataset format. - 3. Create a dataloader from these webdataset tar files. - - So that we don't have to redo any of these steps, we check on the existence of various files to check if this stuff is done already. - """ - # 3. Create a webdataset loader over these images. - # TODO: somehow we need to know which images are in the dataloader. Like, a way to actually go back to the original image. The HF dataset has a "key" column that is likely unique. 
- breakpoint() - - -@beartype.beartype -def get_tol_dataloader( - cfg: config.Activations, preprocess -) -> torch.utils.data.DataLoader: - """ - Get a dataloader for the TreeOfLife-10M dataset. - - Currently does not include a true index or label in the loaded examples. - - Args: - cfg: Config for loading activations. - preprocess: Image transform to be applied to each image. - - Returns: - A PyTorch Dataloader that yields dictionaries with `'image'` keys containing image batches. - """ - assert isinstance(cfg.data, config.TreeOfLifeDataset) - - def transform(sample: dict): - return {"image": preprocess(sample[".jpg"]), "index": sample["__key__"]} - - dataset = wids.ShardListDataset(cfg.data.metadata).add_transform(transform) - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=cfg.vit_batch_size, - shuffle=False, - num_workers=cfg.n_workers, - persistent_workers=False, - ) - - return dataloader - - class TransformedImagenet(torch.utils.data.Dataset): - def __init__(self, cfg: config.ImagenetDataset, transform): + def __init__(self, cfg: config.ImagenetDataset, *, transform=None): import datasets self.hf_dataset = datasets.load_dataset( @@ -773,6 +603,7 @@ def __len__(self) -> int: return len(self.hf_dataset) +@beartype.beartype class TransformedImageFolder(torchvision.datasets.ImageFolder): def __getitem__(self, index: int) -> dict[str, object]: """ @@ -794,7 +625,100 @@ def __getitem__(self, index: int) -> dict[str, object]: @beartype.beartype -def dump(cfg: config.Activations): +class TransformedAde20k(torch.utils.data.Dataset): + class Sample(typing.TypedDict): + img_path: str + seg_path: str + split: str + label: str + target: int + + samples: list[Sample] + + def __init__( + self, cfg: config.Ade20kDataset, *, transform=None, seg_transform=None + ): + self.logger = logging.getLogger("ade20k") + self.cfg = cfg + self.img_dir = os.path.join(cfg.root, "images") + self.seg_dir = os.path.join(cfg.root, "annotations") + self.transform = transform + self.seg_transform = seg_transform + + # Check that we have the right path. + for subdir in ("images", "annotations"): + if not os.path.isdir(os.path.join(cfg.root, subdir)): + # Something is missing. + if os.path.realpath(cfg.root).endswith(subdir): + self.logger.warning( + "The ADE20K root should contain 'images/' and 'annotations/' directories." + ) + raise ValueError(f"Can't find path '{os.path.join(cfg.root, subdir)}'.") + + _, split_mapping = torchvision.datasets.folder.find_classes(self.img_dir) + split_lookup: dict[int, str] = { + value: key for key, value in split_mapping.items() + } + self.loader = torchvision.datasets.folder.default_loader + + # Load all the image paths. 
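+        # `make_dataset` walks the split folders found by `find_classes` above and returns (path, split_index) tuples in sorted order; calling it on both the images and annotations directories yields matching paths, which is what lets `zip(imgs, segs, img_labels)` below pair each image with its segmentation map and scene label (assuming sceneCategories.txt lists images in that same order).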
+ imgs: list[tuple[str, int]] = torchvision.datasets.folder.make_dataset( + self.img_dir, + split_mapping, + extensions=torchvision.datasets.folder.IMG_EXTENSIONS, + ) + + segs: list[tuple[str, int]] = torchvision.datasets.folder.make_dataset( + self.seg_dir, + split_mapping, + extensions=torchvision.datasets.folder.IMG_EXTENSIONS, + ) + + # Load all the targets, classes and mappings + with open(os.path.join(cfg.root, "sceneCategories.txt")) as fd: + img_labels: list[str] = [line.split()[1] for line in fd.readlines()] + + label_set = sorted(set(img_labels)) + label_to_idx = {label: i for i, label in enumerate(label_set)} + + self.samples = [ + self.Sample( + img_path=img_path, + seg_path=seg_path, + split=split_lookup[split], + label=label, + target=label_to_idx[label], + ) + for (img_path, split), (seg_path, _), label in zip(imgs, segs, img_labels) + ] + + def __getitem__(self, index: int) -> dict[str, object]: + # Make a copy + sample = dict(**self.samples[index]) + + sample["image"] = self.loader(sample.pop("img_path")) + if self.transform is not None: + sample["image"] = self.transform(sample["image"]) + + sample["segmentation"] = Image.open(sample.pop("seg_path")).convert("L") + if self.seg_transform is not None: + sample["segmentation"] = self.seg_transform(sample["segmentation"]) + + sample["index"] = index + + return sample + + def __len__(self) -> int: + return len(self.samples) + + +######## +# MAIN # +######## + + +@beartype.beartype +def main(cfg: config.Activations): """ Args: cfg: Config for activations. @@ -967,7 +891,7 @@ def next_shard(self) -> None: @beartype.beartype @dataclasses.dataclass(frozen=True) class Metadata: - model_org: str + model_family: str model_ckpt: str layers: tuple[int, ...] n_patches_per_img: int @@ -981,7 +905,7 @@ class Metadata: @classmethod def from_cfg(cls, cfg: config.Activations) -> "Metadata": return cls( - cfg.model_org, + cfg.model_family, cfg.model_ckpt, tuple(cfg.layers), cfg.n_patches_per_img, diff --git a/saev/config.py b/saev/config.py index e71237e..649932b 100644 --- a/saev/config.py +++ b/saev/config.py @@ -58,7 +58,21 @@ def n_imgs(self) -> int: return n -DatasetConfig = ImagenetDataset | ImageFolderDataset +@beartype.beartype +@dataclasses.dataclass(frozen=True) +class Ade20kDataset: + """ """ + + root: str = os.path.join(".", "data", "split") + """Where the class folders with images are stored.""" + + @property + def n_imgs(self) -> int: + with open(os.path.join(self.root, "sceneCategories.txt")) as fd: + return len(fd.read().split("\n")) + + +DatasetConfig = ImagenetDataset | ImageFolderDataset | Ade20kDataset @beartype.beartype @@ -72,8 +86,8 @@ class Activations: """Which dataset to use.""" dump_to: str = os.path.join(".", "shards") """Where to write shards.""" - model_org: typing.Literal["clip", "siglip", "timm", "dinov2"] = "clip" - """Where to load models from.""" + model_family: typing.Literal["clip", "siglip", "dinov2"] = "clip" + """Which model family.""" model_ckpt: str = "ViT-L-14/openai" """Specific model checkpoint.""" vit_batch_size: int = 1024 diff --git a/saev/nn.py b/saev/nn.py index cafdd1e..421b1cb 100644 --- a/saev/nn.py +++ b/saev/nn.py @@ -18,6 +18,8 @@ class Loss(typing.NamedTuple): + """The composite loss terms for an autoencoder training batch.""" + mse: Float[Tensor, ""] """Reconstruction loss (mean squared error).""" sparsity: Float[Tensor, ""] @@ -61,8 +63,15 @@ def __init__(self, cfg: config.SparseAutoencoder): self.logger = logging.getLogger(f"sae(seed={cfg.seed})") def forward( - self, x: 
Float[Tensor, "batch d_model"], dead_neuron_mask: None = None + self, x: Float[Tensor, "batch d_model"] ) -> tuple[Float[Tensor, "batch d_model"], Float[Tensor, "batch d_sae"], Loss]: + """ + Given x, calculates the reconstructed x_hat, the intermediate activations f_x, and the loss. + + Arguments: + x: a batch of ViT activations. + """ + # Remove encoder bias as per Anthropic h_pre = ( einops.einsum( @@ -108,7 +117,9 @@ def init_b_dec(self, vit_acts: Float[Tensor, "n d_vit"]): @torch.no_grad() def normalize_w_dec(self): - # Make sure the W_dec is still unit-norm + """ + Set W_dec to unit-norm columns. + """ if self.cfg.normalize_w_dec: self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True) diff --git a/saev/visuals.py b/saev/visuals.py index 925fdfd..989b965 100644 --- a/saev/visuals.py +++ b/saev/visuals.py @@ -1,5 +1,4 @@ """ - There is some important notation used only in this file to dramatically shorten variable names. Variables suffixed with `_im` refer to entire images, and variables suffixed with `_p` refer to patches. @@ -76,7 +75,17 @@ def get_new_topk( k: int, ) -> tuple[Float[Tensor, "d_sae k"], Int[Tensor, "d_sae k"]]: """ - .. todo:: document this function. + Picks out the new top k values among val1 and val2. Also keeps track of i1 and i2, then indices of the values in the original dataset. + + Args: + val1: top k original SAE values. + i1: the patch indices of those original top k values. + val2: top k incoming SAE values. + i2: the patch indices of those incoming top k values. + k: k. + + Returns: + The new top k values and their patch indices. """ all_val = torch.cat([val1, val2], dim=1) new_values, top_i = torch.topk(all_val, k=k, dim=1) @@ -399,7 +408,7 @@ def main(cfg: config.Visuals): & (torch.log10(mean_values) < max_log_value) ) - neuron_i = torch.arange(d_sae)[mask.cpu()].tolist() + neuron_i = cfg.include_latents + torch.arange(d_sae)[mask.cpu()].tolist() for i in tqdm.tqdm(neuron_i, desc="saving visuals"): neuron_dir = os.path.join(cfg.root, "neurons", str(i))