Namespace contrib
+Sub-modules
+-
+
contrib.semseg
+- + + +
diff --git a/contrib/semseg/__main__.py b/contrib/semseg/__main__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/contrib/semseg/__main__.py @@ -0,0 +1 @@ + diff --git a/contrib/semseg/config.py b/contrib/semseg/config.py index 1e79d15..e7cd698 100644 --- a/contrib/semseg/config.py +++ b/contrib/semseg/config.py @@ -14,11 +14,11 @@ class Train: """Linear layer learning rate.""" weight_decay: float = 1e-3 """Weight decay for AdamW.""" - n_epochs: int = 10 + n_epochs: int = 100 """Number of training epochs for linear layer.""" - batch_size: int = 32 + batch_size: int = 1024 """Training batch size for linear layer.""" - n_workers: int = 8 + n_workers: int = 32 """Number of dataloader workers.""" train_acts: saev.config.DataLoad = dataclasses.field( default_factory=saev.config.DataLoad diff --git a/contrib/semseg/dashboard.py b/contrib/semseg/dashboard.py index 0c6c2a2..3afeee0 100644 --- a/contrib/semseg/dashboard.py +++ b/contrib/semseg/dashboard.py @@ -15,23 +15,21 @@ def __(): if pkg_root not in sys.path: sys.path.append(pkg_root) - import os import csv - import einops - - import marimo as mo - import sklearn.decomposition + import os - import numpy as np - import datasets + import altair as alt import beartype - from PIL import Image, ImageDraw - from jaxtyping import jaxtyped, Int - - import torch + import datasets + import einops + import marimo as mo import matplotlib.pyplot as plt + import numpy as np import polars as pl - import altair as alt + import sklearn.decomposition + import torch + from jaxtyping import Int, jaxtyped + from PIL import Image, ImageDraw import saev.activations import saev.config diff --git a/contrib/semseg/dashboard2.py b/contrib/semseg/dashboard2.py index 009b80f..2cc908a 100644 --- a/contrib/semseg/dashboard2.py +++ b/contrib/semseg/dashboard2.py @@ -14,17 +14,17 @@ def __(): import random - import marimo as mo import beartype - + import contrib.semantic_seg.training + import marimo as mo import numpy as np import torch - from torchvision.transforms import v2 + from jaxtyping import Int, UInt8, jaxtyped from PIL import Image - from jaxtyping import jaxtyped, UInt8, Int + from torchvision.transforms import v2 - import contrib.semantic_seg.training import saev.config + return ( Image, Int, @@ -45,7 +45,9 @@ def __(): @app.cell def __(contrib): - ckpt_fpath = "/home/stevens.994/projects/saev-live/checkpoints/faithfulness/model.pt" + ckpt_fpath = ( + "/home/stevens.994/projects/saev-live/checkpoints/faithfulness/model.pt" + ) model = contrib.semantic_seg.training.load(ckpt_fpath) model.eval() return ckpt_fpath, model @@ -68,15 +70,12 @@ def __(contrib, saev): @app.cell def __(v2): def make_img_transform(): - return v2.Compose( - [ - v2.Resize(size=224, interpolation=v2.InterpolationMode.NEAREST), - v2.CenterCrop(size=(224, 224)), - # v2.ToImage(), - # v2.ToDtype(torch.uint8), - ] - ) - + return v2.Compose([ + v2.Resize(size=224, interpolation=v2.InterpolationMode.NEAREST), + v2.CenterCrop(size=(224, 224)), + # v2.ToImage(), + # v2.ToDtype(torch.uint8), + ]) img_transform = make_img_transform() return img_transform, make_img_transform @@ -97,7 +96,6 @@ def make_colors(seed: int = 42) -> UInt8[np.ndarray, "n 3"]: colors = np.array(colors, dtype=np.uint8) return colors - @jaxtyped(typechecker=beartype.beartype) def color_map(map: UInt8[np.ndarray, "width height"]) -> Image.Image: colored = np.zeros((224, 224, 3), dtype=np.uint8) @@ -105,6 +103,7 @@ def color_map(map: UInt8[np.ndarray, "width height"]) -> Image.Image: colored[map == i, :] = color return Image.fromarray(colored) + return color_map, make_colors @@ -177,11 +176,17 @@ def mean_iou( if ignore_class is not None: pred_one_hot = torch.cat( - (pred_one_hot[..., :ignore_class], pred_one_hot[..., ignore_class + 1 :]), + ( + pred_one_hot[..., :ignore_class], + pred_one_hot[..., ignore_class + 1 :], + ), axis=-1, ) true_one_hot = torch.cat( - (true_one_hot[..., :ignore_class], true_one_hot[..., ignore_class + 1 :]), + ( + true_one_hot[..., :ignore_class], + true_one_hot[..., ignore_class + 1 :], + ), axis=-1, ) @@ -193,7 +198,6 @@ def mean_iou( # Handle division by zero return ((intersection + eps) / (union + eps)).mean().item() - mean_iou(y_pred, y_true, 151) return (mean_iou,) @@ -281,7 +285,6 @@ def intersect_and_union( return area_intersect, area_union, area_pred_label, area_label - @jaxtyped(typechecker=beartype.beartype) def total_intersect_and_union( results, @@ -321,8 +324,15 @@ def total_intersect_and_union( total_area_pred_label = np.zeros((num_labels,), dtype=np.float64) total_area_label = np.zeros((num_labels,), dtype=np.float64) for result, gt_seg_map in zip(results, gt_seg_maps): - area_intersect, area_union, area_pred_label, area_label = intersect_and_union( - result, gt_seg_map, num_labels, ignore_index, label_map, reduce_labels + area_intersect, area_union, area_pred_label, area_label = ( + intersect_and_union( + result, + gt_seg_map, + num_labels, + ignore_index, + label_map, + reduce_labels, + ) ) total_area_intersect += area_intersect total_area_union += area_union @@ -335,7 +345,6 @@ def total_intersect_and_union( total_area_label, ) - @jaxtyped(typechecker=beartype.beartype) def mean_iou( results, @@ -376,10 +385,13 @@ def mean_iou( - *per_category_iou* (`ndarray` of shape `(num_labels,)`): Per category IoU. """ - total_area_intersect, total_area_union, total_area_pred_label, total_area_label = ( - total_intersect_and_union( - results, gt_seg_maps, num_labels, ignore_index, label_map, reduce_labels - ) + ( + total_area_intersect, + total_area_union, + total_area_pred_label, + total_area_label, + ) = total_intersect_and_union( + results, gt_seg_maps, num_labels, ignore_index, label_map, reduce_labels ) # compute metrics @@ -396,20 +408,20 @@ def mean_iou( metrics["per_category_accuracy"] = acc if nan_to_num is not None: - metrics = dict( - { - metric: np.nan_to_num(metric_value, nan=nan_to_num) - for metric, metric_value in metrics.items() - } - ) + metrics = dict({ + metric: np.nan_to_num(metric_value, nan=nan_to_num) + for metric, metric_value in metrics.items() + }) return metrics + return intersect_and_union, mean_iou, total_intersect_and_union @app.cell def __(): import tensordict + return (tensordict,) diff --git a/contrib/semseg/training.py b/contrib/semseg/training.py index 8b13789..e7f7866 100644 --- a/contrib/semseg/training.py +++ b/contrib/semseg/training.py @@ -1 +1,40 @@ +import beartype +import torch +import saev.config + +from . import config + +n_classes = 151 + + +@beartype.beartype +def main(cfg: config.Train): + train_dataset = Dataset(cfg.train_acts, cfg.train_imgs) + val_dataset = Dataset(cfg.val_acts, cfg.val_imgs) + + model = torch.nn.Linear(train_dataset.d_vit, n_classes) + optim = torch.optim.AdamW( + model.parameters, lr=cfg.learning_rate, weight_decay=cfg.weight_decay + ) + + for epoch in range(cfg.n_epochs): + model.train() + for batch in train_dataloader: + breakpoint() + + model.eval() + for batch in val_dataloader: + breakpoint() + + +@beartype.beartype +class Dataset(torch.utils.data.Dataset): + def __init__( + self, acts_cfg: saev.config.DataLoad, imgs_cfg: saev.config.Ade20kDataset + ): + breakpoint() + + @property + def d_vit(self) -> int: + breakpoint() diff --git a/docs/contrib/index.html b/docs/contrib/index.html new file mode 100644 index 0000000..35a9e5b --- /dev/null +++ b/docs/contrib/index.html @@ -0,0 +1,64 @@ + + +
+ + + +contrib
contrib.semseg
contrib.semseg.config
+class Train
+(ckpt_path: str = './checkpoints/faithfulness', learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 10, batch_size: int = 32, n_workers: int = 8, train_acts: DataLoad = <factory>, val_acts: DataLoad = <factory>, train_imgs: Ade20kDataset = <factory>, val_imgs: Ade20kDataset = <factory>, log_every: int = 10)
+
Train(ckpt_path: str = './checkpoints/faithfulness', learning_rate: float = 0.0001, weight_decay: float = 0.001, n_epochs: int = 10, batch_size: int = 32, n_workers: int = 8, train_acts: saev.config.DataLoad =
@beartype.beartype
+@dataclasses.dataclass(frozen=True)
+class Train:
+ ckpt_path: str = os.path.join(".", "checkpoints", "faithfulness")
+ learning_rate: float = 1e-4
+ """Linear layer learning rate."""
+ weight_decay: float = 1e-3
+ """Weight decay for AdamW."""
+ n_epochs: int = 10
+ """Number of training epochs for linear layer."""
+ batch_size: int = 32
+ """Training batch size for linear layer."""
+ n_workers: int = 8
+ """Number of dataloader workers."""
+ train_acts: saev.config.DataLoad = dataclasses.field(
+ default_factory=saev.config.DataLoad
+ )
+ """Configuration for the saved ADE20K training ViT activations."""
+ val_acts: saev.config.DataLoad = dataclasses.field(
+ default_factory=saev.config.DataLoad
+ )
+ """Configuration for the saved ADE20K validation ViT activations."""
+ train_imgs: saev.config.Ade20kDataset = dataclasses.field(
+ default_factory=lambda: saev.config.Ade20kDataset(split="training")
+ )
+ """Configuration for the training ADE20K dataset."""
+ val_imgs: saev.config.Ade20kDataset = dataclasses.field(
+ default_factory=lambda: saev.config.Ade20kDataset(split="validation")
+ )
+ """Configuration for the validation ADE20K dataset."""
+ log_every: int = 10
+ """How often to log during training."""
+var batch_size : int
Training batch size for linear layer.
var ckpt_path : str
var learning_rate : float
Linear layer learning rate.
var log_every : int
How often to log during training.
var n_epochs : int
Number of training epochs for linear layer.
var n_workers : int
Number of dataloader workers.
var train_acts : DataLoad
Configuration for the saved ADE20K training ViT activations.
var train_imgs : Ade20kDataset
Configuration for the training ADE20K dataset.
var val_acts : DataLoad
Configuration for the saved ADE20K validation ViT activations.
var val_imgs : Ade20kDataset
Configuration for the validation ADE20K dataset.
var weight_decay : float
Weight decay +for AdamW.
contrib.semseg.dashboard
contrib.semseg.dashboard2
contrib.semseg
contrib.semseg.config
contrib.semseg.dashboard
contrib.semseg.dashboard2
contrib.semseg.training
contrib.semseg.training
Directory to where activations should be dumped/loaded from.
-def get_dataloader(cfg: Activations, preprocess)
+def get_dataloader(cfg: Activations, *, img_transform=None)
Gets the dataloader for the current experiment; delegates dataloader construction to dataset-specific functions.
@@ -65,14 +65,14 @@cfg
preprocess
img_transform
A PyTorch Dataloader that yields dictionaries with 'image'
keys containing image batches.
-def get_dataset(cfg: ImagenetDataset | ImageFolderDataset | Ade20kDataset, *, transform)
+def get_dataset(cfg: ImagenetDataset | ImageFolderDataset | Ade20kDataset, *, img_transform)
Gets the dataset for the current experiment; delegates construction to dataset-specific functions.
@@ -80,14 +80,14 @@cfg
transform
img_transform
A dataset that has dictionaries with 'image'
, 'index'
, 'target'
, and 'label'
keys containing examples.
-def get_default_dataloader(cfg: Activations, transform) ‑> torch.utils.data.dataloader.DataLoader
+def get_default_dataloader(cfg: Activations, *, img_transform: collections.abc.Callable) ‑> torch.utils.data.dataloader.DataLoader
Get a dataloader for a default map-style dataset.
@@ -95,7 +95,7 @@cfg
preprocess
img_transform
-def make_img_transform(model_family: str, model_ckpt: str) ‑> Callable
+def make_img_transform(model_family: str, model_ckpt: str) ‑> collections.abc.Callable
+class Ade20k
+(cfg: Ade20kDataset, *, img_transform: collections.abc.Callable | None = None, seg_transform: collections.abc.Callable | None = <function Ade20k.<lambda>>)
+
An abstract class representing a :class:Dataset
.
All datasets that represent a map from keys to data samples should subclass
+it. All subclasses should overwrite :meth:__getitem__
, supporting fetching a
+data sample for a given key. Subclasses could also optionally overwrite
+:meth:__len__
, which is expected to return the size of the dataset by many
+:class:~torch.utils.data.Sampler
implementations and the default options
+of :class:~torch.utils.data.DataLoader
. Subclasses could also
+optionally implement :meth:__getitems__
, for speedup batched samples
+loading. This method accepts list of indices of samples of batch and returns
+list of samples.
Note
+:class:~torch.utils.data.DataLoader
by default constructs an index
+sampler that yields integral indices.
+To make it work with a map-style
+dataset with non-integral indices/keys, a custom sampler must be provided.
@beartype.beartype
+class Ade20k(torch.utils.data.Dataset):
+ @beartype.beartype
+ @dataclasses.dataclass(frozen=True)
+ class Sample:
+ img_path: str
+ seg_path: str
+ label: str
+ target: int
+
+ samples: list[Sample]
+
+ def __init__(
+ self,
+ cfg: config.Ade20kDataset,
+ *,
+ img_transform: collections.abc.Callable | None = None,
+ seg_transform: collections.abc.Callable | None = lambda x: None,
+ ):
+ self.logger = logging.getLogger("ade20k")
+ self.cfg = cfg
+ self.img_dir = os.path.join(cfg.root, "images")
+ self.seg_dir = os.path.join(cfg.root, "annotations")
+ self.img_transform = img_transform
+ self.seg_transform = seg_transform
+
+ # Check that we have the right path.
+ for subdir in ("images", "annotations"):
+ if not os.path.isdir(os.path.join(cfg.root, subdir)):
+ # Something is missing.
+ if os.path.realpath(cfg.root).endswith(subdir):
+ self.logger.warning(
+ "The ADE20K root should contain 'images/' and 'annotations/' directories."
+ )
+ raise ValueError(f"Can't find path '{os.path.join(cfg.root, subdir)}'.")
+
+ _, split_mapping = torchvision.datasets.folder.find_classes(self.img_dir)
+ split_lookup: dict[int, str] = {
+ value: key for key, value in split_mapping.items()
+ }
+ self.loader = torchvision.datasets.folder.default_loader
+
+ assert cfg.split in set(split_lookup.values())
+
+ # Load all the image paths.
+ imgs: list[str] = [
+ path
+ for path, s in torchvision.datasets.folder.make_dataset(
+ self.img_dir,
+ split_mapping,
+ extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
+ )
+ if split_lookup[s] == cfg.split
+ ]
+
+ segs: list[str] = [
+ path
+ for path, s in torchvision.datasets.folder.make_dataset(
+ self.seg_dir,
+ split_mapping,
+ extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
+ )
+ if split_lookup[s] == cfg.split
+ ]
+
+ # Load all the targets, classes and mappings
+ with open(os.path.join(cfg.root, "sceneCategories.txt")) as fd:
+ img_labels: list[str] = [line.split()[1] for line in fd.readlines()]
+
+ label_set = sorted(set(img_labels))
+ label_to_idx = {label: i for i, label in enumerate(label_set)}
+
+ self.samples = [
+ self.Sample(img_path, seg_path, label, label_to_idx[label])
+ for img_path, seg_path, label in zip(imgs, segs, img_labels)
+ ]
+
+ def __getitem__(self, index: int) -> dict[str, object]:
+ # Convert to dict.
+ sample = dataclasses.asdict(self.samples[index])
+
+ sample["image"] = self.loader(sample.pop("img_path"))
+ if self.img_transform is not None:
+ image = self.img_transform(sample.pop("image"))
+ if image is not None:
+ sample["image"] = image
+
+ sample["segmentation"] = Image.open(sample.pop("seg_path")).convert("L")
+ if self.seg_transform is not None:
+ segmentation = self.seg_transform(sample.pop("segmentation"))
+ if segmentation is not None:
+ sample["segmentation"] = segmentation
+
+ sample["index"] = index
+
+ return sample
+
+ def __len__(self) -> int:
+ return len(self.samples)
+var Sample
var samples : list[Ade20k.Sample]
class Clip
(cfg: Activations)
@@ -658,6 +801,135 @@ Methods
+class ImageFolder
+(root: Union[str, pathlib.Path], transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, loader: Callable[[str], Any] = <function default_loader>, is_valid_file: Optional[Callable[[str], bool]] = None, allow_empty: bool = False)
+
A generic data loader where the images are arranged in this way by default: ::
+root/dog/xxx.png
+root/dog/xxy.png
+root/dog/[...]/xxz.png
+
+root/cat/123.png
+root/cat/nsdf3.png
+root/cat/[...]/asd932_.png
+
+This class inherits from :class:~torchvision.datasets.DatasetFolder
so
+the same methods can be overridden to customize the dataset.
pathlib.Path
): Root directory path.transform
: callable
, optionaltransforms.RandomCrop
target_transform
: callable
, optionalloader
: callable
, optionalis_valid_file
: callable
, optionalallow_empty(bool, optional): If True, empty folders are considered to be valid classes. +An error is raised on empty folders if False (default). +Attributes: +classes (list): List of the class names sorted alphabetically. +class_to_idx (dict): Dict with items (class_name, class_index). +imgs (list): List of (image path, class_index) tuples
@beartype.beartype
+class ImageFolder(torchvision.datasets.ImageFolder):
+ def __getitem__(self, index: int) -> dict[str, object]:
+ """
+ Args:
+ index: Index
+
+ Returns:
+ dict with keys 'image', 'index', 'target' and 'label'.
+ """
+ breakpoint()
+ path, target = self.samples[index]
+ sample = self.loader(path)
+ if self.img_transform is not None:
+ sample = self.img_transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return {"image": sample, "target": target, "index": index}
+
+class Imagenet
+(cfg: ImagenetDataset, *, img_transform=None)
+
An abstract class representing a :class:Dataset
.
All datasets that represent a map from keys to data samples should subclass
+it. All subclasses should overwrite :meth:__getitem__
, supporting fetching a
+data sample for a given key. Subclasses could also optionally overwrite
+:meth:__len__
, which is expected to return the size of the dataset by many
+:class:~torch.utils.data.Sampler
implementations and the default options
+of :class:~torch.utils.data.DataLoader
. Subclasses could also
+optionally implement :meth:__getitems__
, for speedup batched samples
+loading. This method accepts list of indices of samples of batch and returns
+list of samples.
Note
+:class:~torch.utils.data.DataLoader
by default constructs an index
+sampler that yields integral indices.
+To make it work with a map-style
+dataset with non-integral indices/keys, a custom sampler must be provided.
@beartype.beartype
+class Imagenet(torch.utils.data.Dataset):
+ def __init__(self, cfg: config.ImagenetDataset, *, img_transform=None):
+ import datasets
+
+ self.hf_dataset = datasets.load_dataset(
+ cfg.name, split=cfg.split, trust_remote_code=True
+ )
+
+ self.img_transform = img_transform
+ self.labels = self.hf_dataset.info.features["label"].names
+
+ def __getitem__(self, i):
+ example = self.hf_dataset[i]
+ example["index"] = i
+
+ example["image"] = example["image"].convert("RGB")
+ if self.img_transform:
+ example["image"] = self.img_transform(example["image"])
+ example["target"] = example.pop("label")
+ example["label"] = self.labels[example["target"]]
+
+ return example
+
+ def __len__(self) -> int:
+ return len(self.hf_dataset)
+
class Metadata
(model_family: str, model_ckpt: str, layers: tuple[int, ...], n_patches_per_img: int, cls_token: bool, d_vit: int, seed: int, n_imgs: int, n_patches_per_shard: int, data: str)
@@ -1022,275 +1294,6 @@ Methods
-
-class TransformedAde20k
-(cfg: Ade20kDataset, *, transform=None, seg_transform=None)
-
-
-An abstract class representing a :class:Dataset
.
-All datasets that represent a map from keys to data samples should subclass
-it. All subclasses should overwrite :meth:__getitem__
, supporting fetching a
-data sample for a given key. Subclasses could also optionally overwrite
-:meth:__len__
, which is expected to return the size of the dataset by many
-:class:~torch.utils.data.Sampler
implementations and the default options
-of :class:~torch.utils.data.DataLoader
. Subclasses could also
-optionally implement :meth:__getitems__
, for speedup batched samples
-loading. This method accepts list of indices of samples of batch and returns
-list of samples.
-
-Note
-:class:~torch.utils.data.DataLoader
by default constructs an index
-sampler that yields integral indices.
-To make it work with a map-style
-dataset with non-integral indices/keys, a custom sampler must be provided.
-
-
-
-Expand source code
-
-@beartype.beartype
-class TransformedAde20k(torch.utils.data.Dataset):
- class Sample(typing.TypedDict):
- img_path: str
- seg_path: str
- split: str
- label: str
- target: int
-
- samples: list[Sample]
-
- def __init__(
- self, cfg: config.Ade20kDataset, *, transform=None, seg_transform=None
- ):
- self.logger = logging.getLogger("ade20k")
- self.cfg = cfg
- self.img_dir = os.path.join(cfg.root, "images")
- self.seg_dir = os.path.join(cfg.root, "annotations")
- self.transform = transform
- self.seg_transform = seg_transform
-
- # Check that we have the right path.
- for subdir in ("images", "annotations"):
- if not os.path.isdir(os.path.join(cfg.root, subdir)):
- # Something is missing.
- if os.path.realpath(cfg.root).endswith(subdir):
- self.logger.warning(
- "The ADE20K root should contain 'images/' and 'annotations/' directories."
- )
- raise ValueError(f"Can't find path '{os.path.join(cfg.root, subdir)}'.")
-
- _, split_mapping = torchvision.datasets.folder.find_classes(self.img_dir)
- split_lookup: dict[int, str] = {
- value: key for key, value in split_mapping.items()
- }
- self.loader = torchvision.datasets.folder.default_loader
-
- # Load all the image paths.
- imgs: list[tuple[str, int]] = torchvision.datasets.folder.make_dataset(
- self.img_dir,
- split_mapping,
- extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
- )
-
- segs: list[tuple[str, int]] = torchvision.datasets.folder.make_dataset(
- self.seg_dir,
- split_mapping,
- extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
- )
-
- # Load all the targets, classes and mappings
- with open(os.path.join(cfg.root, "sceneCategories.txt")) as fd:
- img_labels: list[str] = [line.split()[1] for line in fd.readlines()]
-
- label_set = sorted(set(img_labels))
- label_to_idx = {label: i for i, label in enumerate(label_set)}
-
- self.samples = [
- self.Sample(
- img_path=img_path,
- seg_path=seg_path,
- split=split_lookup[split],
- label=label,
- target=label_to_idx[label],
- )
- for (img_path, split), (seg_path, _), label in zip(imgs, segs, img_labels)
- ]
-
- def __getitem__(self, index: int) -> dict[str, object]:
- # Make a copy
- sample = dict(**self.samples[index])
-
- sample["image"] = self.loader(sample.pop("img_path"))
- if self.transform is not None:
- sample["image"] = self.transform(sample["image"])
-
- sample["segmentation"] = Image.open(sample.pop("seg_path")).convert("L")
- if self.seg_transform is not None:
- sample["segmentation"] = self.seg_transform(sample["segmentation"])
-
- sample["index"] = index
-
- return sample
-
- def __len__(self) -> int:
- return len(self.samples)
-
-Ancestors
-
-- torch.utils.data.dataset.Dataset
-- typing.Generic
-
-Class variables
-
-var Sample
--
-
dict() -> new empty dictionary
-dict(mapping) -> new dictionary initialized from a mapping object's
-(key, value) pairs
-dict(iterable) -> new dictionary initialized as if via:
-d = {}
-for k, v in iterable:
-d[k] = v
-dict(**kwargs) -> new dictionary initialized with the name=value pairs
-in the keyword argument list.
-For example:
-dict(one=1, two=2)
-
-var samples : list[TransformedAde20k.Sample]
--
-
-
-
-
-
-class TransformedImageFolder
-(root: Union[str, pathlib.Path], transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, loader: Callable[[str], Any] = <function default_loader>, is_valid_file: Optional[Callable[[str], bool]] = None, allow_empty: bool = False)
-
-
-A generic data loader where the images are arranged in this way by default: ::
-root/dog/xxx.png
-root/dog/xxy.png
-root/dog/[...]/xxz.png
-
-root/cat/123.png
-root/cat/nsdf3.png
-root/cat/[...]/asd932_.png
-
-This class inherits from :class:~torchvision.datasets.DatasetFolder
so
-the same methods can be overridden to customize the dataset.
-Args
-
-- root (str or
pathlib.Path
): Root directory path.
-transform
: callable
, optional
-- A function/transform that takes in a PIL image
-and returns a transformed version. E.g,
transforms.RandomCrop
-target_transform
: callable
, optional
-- A function/transform that takes in the
-target and transforms it.
-loader
: callable
, optional
-- A function to load an image given its path.
-is_valid_file
: callable
, optional
-- A function that takes path of an Image file
-and check if the file is a valid file (used to check of corrupt files)
-
-allow_empty(bool, optional): If True, empty folders are considered to be valid classes.
-An error is raised on empty folders if False (default).
-Attributes:
-classes (list): List of the class names sorted alphabetically.
-class_to_idx (dict): Dict with items (class_name, class_index).
-imgs (list): List of (image path, class_index) tuples
-
-
-Expand source code
-
-@beartype.beartype
-class TransformedImageFolder(torchvision.datasets.ImageFolder):
- def __getitem__(self, index: int) -> dict[str, object]:
- """
- Args:
- index: Index
-
- Returns:
- dict with keys 'image', 'index', 'target' and 'label'.
- """
- breakpoint()
- path, target = self.samples[index]
- sample = self.loader(path)
- if self.transform is not None:
- sample = self.transform(sample)
- if self.target_transform is not None:
- target = self.target_transform(target)
-
- return {"image": sample, "target": target, "index": index}
-
-Ancestors
-
-- torchvision.datasets.folder.ImageFolder
-- torchvision.datasets.folder.DatasetFolder
-- torchvision.datasets.vision.VisionDataset
-- torch.utils.data.dataset.Dataset
-- typing.Generic
-
-
-
-class TransformedImagenet
-(cfg: ImagenetDataset, *, transform=None)
-
-
-An abstract class representing a :class:Dataset
.
-All datasets that represent a map from keys to data samples should subclass
-it. All subclasses should overwrite :meth:__getitem__
, supporting fetching a
-data sample for a given key. Subclasses could also optionally overwrite
-:meth:__len__
, which is expected to return the size of the dataset by many
-:class:~torch.utils.data.Sampler
implementations and the default options
-of :class:~torch.utils.data.DataLoader
. Subclasses could also
-optionally implement :meth:__getitems__
, for speedup batched samples
-loading. This method accepts list of indices of samples of batch and returns
-list of samples.
-
-Note
-:class:~torch.utils.data.DataLoader
by default constructs an index
-sampler that yields integral indices.
-To make it work with a map-style
-dataset with non-integral indices/keys, a custom sampler must be provided.
-
-
-
-Expand source code
-
-@beartype.beartype
-class TransformedImagenet(torch.utils.data.Dataset):
- def __init__(self, cfg: config.ImagenetDataset, *, transform=None):
- import datasets
-
- self.hf_dataset = datasets.load_dataset(
- cfg.name, split=cfg.split, trust_remote_code=True
- )
-
- self.transform = transform
- self.labels = self.hf_dataset.info.features["label"].names
-
- def __getitem__(self, i):
- example = self.hf_dataset[i]
- example["index"] = i
-
- example["image"] = example["image"].convert("RGB")
- if self.transform:
- example["image"] = self.transform(example["image"])
- example["target"] = example.pop("label")
- example["label"] = self.labels[example["target"]]
-
- return example
-
- def __len__(self) -> int:
- return len(self.hf_dataset)
-
-Ancestors
-
-- torch.utils.data.dataset.Dataset
-- typing.Generic
-
-
class VitRecorder
(cfg: Activations, patches: slice = slice(None, None, None))
@@ -1471,6 +1474,13 @@ Methods
Classes
-
+
Ade20k
+
+
+-
Clip
-
+
ImageFolder
+
+-
+
Imagenet
+
+-
Metadata
-
-
TransformedAde20k
-
-
--
-
TransformedImageFolder
-
--
-
TransformedImagenet
-
--
VitRecorder
activations
diff --git a/docs/saev/config.html b/docs/saev/config.html
index 8121b1f..d72f60b 100644
--- a/docs/saev/config.html
+++ b/docs/saev/config.html
@@ -188,7 +188,7 @@ Class variables
class Ade20kDataset
-(root: str = './data/split')
+(root: str = './data/split', split: Literal['training', 'validation'] = 'training')
-
@@ -203,11 +203,15 @@
Class variables
root: str = os.path.join(".", "data", "split")
"""Where the class folders with images are stored."""
+ split: typing.Literal["training", "validation"] = "training"
+ """Data split."""
@property
def n_imgs(self) -> int:
- with open(os.path.join(self.root, "sceneCategories.txt")) as fd:
- return len(fd.read().split("\n"))
+ if self.split == "validation":
+ return 2000
+ else:
+ return 20210
Where the class folders with images are stored.
var split : Literal['training', 'validation']
Data split.
@property
def n_imgs(self) -> int:
- with open(os.path.join(self.root, "sceneCategories.txt")) as fd:
- return len(fd.read().split("\n"))
+ if self.split == "validation":
+ return 2000
+ else:
+ return 20210
n_imgs
root
+split
diff --git a/probing/__init__.py b/probing/__init__.py
deleted file mode 100644
index 816858d..0000000
--- a/probing/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""
-Package for probing for individual features in trained sparse autoencoders.
-
-.. include:: ./description.md
-"""
diff --git a/probing/__main__.py b/probing/__main__.py
deleted file mode 100644
index a5700f7..0000000
--- a/probing/__main__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import typing
-
-import tyro
-
-from . import config
-
-
-def dump_topk(cfg: typing.Annotated[config.Topk, tyro.conf.arg(name="")]):
- from .dump_topk import main
-
- main(cfg)
-
-
-if __name__ == "__main__":
- tyro.extras.subcommand_cli_from_dict({
- "dump-topk": dump_topk,
- "nothing": lambda: print("dummy."),
- })
diff --git a/probing/config.py b/probing/config.py
deleted file mode 100644
index 10b2a15..0000000
--- a/probing/config.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import dataclasses
-import os
-import typing
-
-import beartype
-
-from saev import config
-
-
-@beartype.beartype
-@dataclasses.dataclass(frozen=True)
-class Probe:
- ckpt: str = os.path.join(".", "checkpoints", "abcdefg", "sae.pt")
- """Path to the sae.pt file."""
- data: config.DataLoad = dataclasses.field(default_factory=config.DataLoad)
- """ViT activations for probing tasks."""
- n_workers: int = 8
- """Number of dataloader workers."""
- sae_batch_size: int = 1024 * 16
- """Batch size for SAE inference."""
- device: str = "cuda"
- """Which accelerator to use."""
- images: config.ImageFolderDataset = dataclasses.field(
- default_factory=config.ImageFolderDataset
- )
- """Where the raw images are."""
- dump_to: str = os.path.join(".", "logs", "probes")
- """Where to save images."""
-
-
-@beartype.beartype
-@dataclasses.dataclass(frozen=True)
-class Topk:
- """.. todo:: document."""
-
- ckpt: str = os.path.join(".", "checkpoints", "sae.pt")
- """Path to the sae.pt file."""
- data: config.DataLoad = dataclasses.field(default_factory=config.DataLoad)
- """Data configuration."""
- images: config.ImagenetDataset | config.ImageFolderDataset = dataclasses.field(
- default_factory=config.ImagenetDataset
- )
- """Which images to use."""
- top_k: int = 16
- """How many images per SAE feature to store."""
- n_workers: int = 16
- """Number of dataloader workers."""
- topk_batch_size: int = 1024 * 16
- """Number of examples to apply top-k op to."""
- sae_batch_size: int = 1024 * 16
- """Batch size for SAE inference."""
- epsilon: float = 1e-9
- """Value to add to avoid log(0)."""
- sort_by: typing.Literal["cls", "img", "patch"] = "cls"
- """How to find the top k images. 'cls' picks images where the SAE latents of the ViT's [CLS] token are maximized without any patch highligting. 'img' picks images that maximize the sum of an SAE latent over all patches in the image, highlighting the patches. 'patch' pickes images that maximize an SAE latent over all patches (not summed), highlighting the patches and only showing unique images."""
- device: str = "cuda"
- """Which accelerator to use."""
- dump_to: str = os.path.join(".", "data")
- """Where to save data."""
- log_freq_range: tuple[float, float] = (-6.0, -2.0)
- """Log10 frequency range for which to save images."""
- log_value_range: tuple[float, float] = (-1.0, 1.0)
- """Log10 frequency range for which to save images."""
- include_latents: list[int] = dataclasses.field(default_factory=list)
- """Latents to always include, no matter what."""
-
- @property
- def root(self) -> str:
- return os.path.join(self.dump_to, f"sort_by_{self.sort_by}")
-
- @property
- def top_values_fpath(self) -> str:
- return os.path.join(self.root, "top_values.pt")
-
- @property
- def top_img_i_fpath(self) -> str:
- return os.path.join(self.root, "top_img_i.pt")
-
- @property
- def top_patch_i_fpath(self) -> str:
- return os.path.join(self.root, "top_patch_i.pt")
-
- @property
- def mean_values_fpath(self) -> str:
- return os.path.join(self.root, "mean_values.pt")
-
- @property
- def sparsity_fpath(self) -> str:
- return os.path.join(self.root, "sparsity.pt")
diff --git a/probing/description.md b/probing/description.md
deleted file mode 100644
index 5ffac45..0000000
--- a/probing/description.md
+++ /dev/null
@@ -1,18 +0,0 @@
-How can we find examples of interesting features in different checkpoints? How can we predict hypothetical trends in the way different models learn different features?
-
-In this work, we assume that manual effort is necessary. This package contains methods for trying to make it as easy as possible to identify interesting SAEs and predict trends between models.
-
-# An Observational Study of Vision Model Interpretability
-
-We have multiple ways to compare trained SAEs.
-
-1. Heuristic measures, like number of dead features, number of dense features, mean L0 norm, etc.
-2. Qualitative plots of feature frequency, mean activation value, L0-MSE tradeoff curves, etc.
-3. Manual inspection of the top K images for each feature.
-
-After proposing trends, we can construct individual probing datasets (see below).
-
-![experimental-design](docs/assets/experiment1.png)
-
-
-`probing.notebooks.l0_mse_tradeoff` is a notebook to explore the L0-MSE tradeoff as well as feature frequency and mean activation value distributions.
diff --git a/probing/dump_topk.py b/probing/dump_topk.py
deleted file mode 100644
index ab8d00c..0000000
--- a/probing/dump_topk.py
+++ /dev/null
@@ -1,440 +0,0 @@
-"""
-
-There is some important notation used only in this file to dramatically shorten variable names.
-
-Variables suffixed with `_im` refer to entire images, and variables suffixed with `_p` refer to patches.
-"""
-
-import collections.abc
-import dataclasses
-import logging
-import os
-import pickle
-import typing
-
-import beartype
-import torch
-import tqdm
-from jaxtyping import Float, Int, jaxtyped
-from PIL import Image
-from torch import Tensor
-
-from saev import activations, helpers, imaging, nn
-
-from . import config
-
-logger = logging.getLogger("webapp")
-
-
-@beartype.beartype
-def safe_load(path: str) -> object:
- return torch.load(path, map_location="cpu", weights_only=True)
-
-
-@jaxtyped(typechecker=beartype.beartype)
-def gather_batched(
- value: Float[Tensor, "batch n dim"], i: Int[Tensor, "batch k"]
-) -> Float[Tensor, "batch k dim"]:
- batch_size, n, dim = value.shape # noqa: F841
- _, k = i.shape
-
- batch_i = torch.arange(batch_size, device=value.device)[:, None].expand(-1, k)
- return value[batch_i, i]
-
-
-@jaxtyped(typechecker=beartype.beartype)
-@dataclasses.dataclass
-class GridElement:
- img: Image.Image
- label: str
- patches: Float[Tensor, " n_patches"]
-
-
-@beartype.beartype
-def make_img(elem: GridElement, *, upper: float | None = None) -> Image.Image:
- # Resize to 256x256 and crop to 224x224
- resize_size_px = (512, 512)
- resize_w_px, resize_h_px = resize_size_px
- crop_size_px = (448, 448)
- crop_w_px, crop_h_px = crop_size_px
- crop_coords_px = (
- (resize_w_px - crop_w_px) // 2,
- (resize_h_px - crop_h_px) // 2,
- (resize_w_px + crop_w_px) // 2,
- (resize_h_px + crop_h_px) // 2,
- )
-
- img = elem.img.resize(resize_size_px).crop(crop_coords_px)
- img = imaging.add_highlights(img, elem.patches.numpy(), upper=upper)
- return img
-
-
-@jaxtyped(typechecker=beartype.beartype)
-def get_new_topk(
- val1: Float[Tensor, "d_sae k"],
- i1: Int[Tensor, "d_sae k"],
- val2: Float[Tensor, "d_sae k"],
- i2: Int[Tensor, "d_sae k"],
- k: int,
-) -> tuple[Float[Tensor, "d_sae k"], Int[Tensor, "d_sae k"]]:
- """
- .. todo:: document this function.
- """
- all_val = torch.cat([val1, val2], dim=1)
- new_values, top_i = torch.topk(all_val, k=k, dim=1)
-
- all_i = torch.cat([i1, i2], dim=1)
- new_indices = torch.gather(all_i, 1, top_i)
- return new_values, new_indices
-
-
-@beartype.beartype
-def batched_idx(
- total_size: int, batch_size: int
-) -> collections.abc.Iterator[tuple[int, int]]:
- """
- Iterate over (start, end) indices for total_size examples, where end - start is at most batch_size.
-
- Args:
- total_size: total number of examples
- batch_size: maximum distance between the generated indices.
-
- Returns:
- A generator of (int, int) tuples that can slice up a list or a tensor.
- """
- for start in range(0, total_size, batch_size):
- stop = min(start + batch_size, total_size)
- yield start, stop
-
-
-@jaxtyped(typechecker=beartype.beartype)
-def get_sae_acts(
- vit_acts: Float[Tensor, "n d_vit"], sae: nn.SparseAutoencoder, cfg: config.Topk
-) -> Float[Tensor, "n d_sae"]:
- """
- Get SAE hidden layer activations for a batch of ViT activations.
-
- Args:
- vit_acts: Batch of ViT activations
- sae: Sparse autoencder.
- cfg: Experimental config.
- """
- sae_acts = []
- for start, end in batched_idx(len(vit_acts), cfg.sae_batch_size):
- _, f_x, *_ = sae(vit_acts[start:end].to(cfg.device))
- sae_acts.append(f_x)
-
- sae_acts = torch.cat(sae_acts, dim=0)
- sae_acts = sae_acts.to(cfg.device)
- return sae_acts
-
-
-@beartype.beartype
-@torch.inference_mode()
-def get_topk_cls(
- cfg: config.Topk,
-) -> tuple[
- Float[Tensor, "d_sae k 0"],
- Int[Tensor, "d_sae k"],
- Float[Tensor, " d_sae"],
- Float[Tensor, " d_sae"],
-]:
- assert cfg.sort_by == "cls"
- raise NotImplementedError()
-
-
-@beartype.beartype
-@torch.inference_mode()
-def get_topk_img(
- cfg: config.Topk,
-) -> tuple[
- Float[Tensor, "d_sae k n_patches_per_img"],
- Int[Tensor, "d_sae k"],
- Float[Tensor, " d_sae"],
- Float[Tensor, " d_sae"],
-]:
- """
- .. todo:: Document this.
- """
- assert cfg.sort_by == "img"
- assert cfg.data.patches == "patches"
-
- sae = nn.load(cfg.ckpt).to(cfg.device)
- dataset = activations.Dataset(cfg.data)
-
- top_values_p = torch.full(
- (sae.cfg.d_sae, cfg.top_k, dataset.metadata.n_patches_per_img),
- -1.0,
- device=cfg.device,
- )
- top_i_im = torch.zeros(
- (sae.cfg.d_sae, cfg.top_k), dtype=torch.int, device=cfg.device
- )
-
- sparsity = torch.zeros((sae.cfg.d_sae,), device=cfg.device)
- mean_values = torch.zeros((sae.cfg.d_sae,), device=cfg.device)
-
- batch_size = (
- cfg.topk_batch_size
- // dataset.metadata.n_patches_per_img
- * dataset.metadata.n_patches_per_img
- )
- n_imgs_per_batch = batch_size // dataset.metadata.n_patches_per_img
-
- dataloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=batch_size,
- shuffle=False,
- num_workers=cfg.n_workers,
- # See if you can change this to false and still pass the beartype check.
- drop_last=True,
- )
-
- logger.info("Loaded SAE and data.")
-
- for vit_acts, i_im, _ in helpers.progress(dataloader, desc="picking top-k"):
- sae_acts = get_sae_acts(vit_acts, sae, cfg).transpose(0, 1)
- mean_values += sae_acts.sum(dim=1)
- sparsity += (sae_acts > 0).sum(dim=1)
-
- values_p = sae_acts.view(sae.cfg.d_sae, -1, dataset.metadata.n_patches_per_img)
- values_im = values_p.sum(axis=-1)
- i_im = torch.sort(torch.unique(i_im)).values
-
- # Checks that I did my reshaping correctly.
- assert values_p.shape[1] == i_im.shape[0]
- assert len(i_im) == n_imgs_per_batch
-
- # Pick out the top 16 images for each latent in this batch.
- values_im, i = torch.topk(values_im, k=cfg.top_k, dim=1)
- # Update patch-level values
- shape_in = (
- sae.cfg.d_sae * n_imgs_per_batch,
- dataset.metadata.n_patches_per_img,
- )
- shape_out = (sae.cfg.d_sae, cfg.top_k, dataset.metadata.n_patches_per_img)
- values_p = values_p.reshape(shape_in)[i.view(-1)].reshape(shape_out)
- # Update image indices
- i_im = i_im.to(cfg.device)[i.view(-1)].view(i.shape)
-
- # Pick out the top 16 images for each latent overall.
- top_values_im = top_values_p.sum(axis=-1)
- all_values_p = torch.cat((top_values_p, values_p), dim=1)
- all_values_im = torch.cat((top_values_im, values_im), dim=1)
- _, j = torch.topk(all_values_im, k=cfg.top_k, dim=1)
-
- shape_in = (sae.cfg.d_sae * cfg.top_k * 2, dataset.metadata.n_patches_per_img)
- top_values_p = all_values_p.reshape(shape_in)[j.view(-1)].reshape(
- top_values_p.shape
- )
-
- all_top_i = torch.cat((top_i_im, i_im), dim=1)
- top_i_im = torch.gather(all_top_i, 1, j)
-
- mean_values /= sparsity
- sparsity /= len(dataset)
-
- return top_values_p, top_i_im, mean_values, sparsity
-
-
-@beartype.beartype
-@torch.inference_mode()
-def get_topk_patch(
- cfg: config.Topk,
-) -> tuple[
- Float[Tensor, "d_sae k n_patches_per_img"],
- Int[Tensor, "d_sae k"],
- Float[Tensor, " d_sae"],
- Float[Tensor, " d_sae"],
-]:
- """
- Gets the top k images for each latent in the SAE.
- The top k images are for latent i are sorted by
-
- max over all patches: f_x(patch)[i]
-
- Thus, we could end up with duplicate images in the top k, if an image has more than one patch that maximally activates an SAE latent.
-
- Args:
- cfg: Config.
-
- Returns:
-
- """
- assert cfg.sort_by == "patch"
- assert cfg.data.patches == "patches"
-
- sae = nn.load(cfg.ckpt).to(cfg.device)
- dataset = activations.Dataset(cfg.data)
-
- top_values_p = torch.full(
- (sae.cfg.d_sae, cfg.top_k, dataset.metadata.n_patches_per_img),
- -1.0,
- device=cfg.device,
- )
- top_i_im = torch.zeros(
- (sae.cfg.d_sae, cfg.top_k), dtype=torch.int, device=cfg.device
- )
-
- sparsity = torch.zeros((sae.cfg.d_sae,), device=cfg.device)
- mean_values = torch.zeros((sae.cfg.d_sae,), device=cfg.device)
-
- batch_size = (
- cfg.topk_batch_size
- // dataset.metadata.n_patches_per_img
- * dataset.metadata.n_patches_per_img
- )
- n_imgs_per_batch = batch_size // dataset.metadata.n_patches_per_img
-
- dataloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=batch_size,
- shuffle=False,
- num_workers=cfg.n_workers,
- # See if you can change this to false and still pass the beartype check.
- drop_last=True,
- )
-
- logger.info("Loaded SAE and data.")
-
- for vit_acts, i_im, _ in helpers.progress(dataloader, desc="picking top-k"):
- sae_acts = get_sae_acts(vit_acts, sae, cfg).transpose(0, 1)
- mean_values += sae_acts.sum(dim=1)
- sparsity += (sae_acts > 0).sum(dim=1)
-
- i_im = torch.sort(torch.unique(i_im)).values
- values_p = sae_acts.view(
- sae.cfg.d_sae, len(i_im), dataset.metadata.n_patches_per_img
- )
-
- # Checks that I did my reshaping correctly.
- assert values_p.shape[1] == i_im.shape[0]
- assert len(i_im) == n_imgs_per_batch
-
- _, k = torch.topk(sae_acts, k=cfg.top_k, dim=1)
- k_im = k // dataset.metadata.n_patches_per_img
-
- values_p = gather_batched(values_p, k_im)
- i_im = i_im.to(cfg.device)[k_im]
-
- all_values_p = torch.cat((top_values_p, values_p), axis=1)
- _, k = torch.topk(all_values_p.max(axis=-1).values, k=cfg.top_k, axis=1)
-
- top_values_p = gather_batched(all_values_p, k)
- top_i_im = torch.gather(torch.cat((top_i_im, i_im), axis=1), 1, k)
-
- mean_values /= sparsity
- sparsity /= len(dataset)
-
- return top_values_p, top_i_im, mean_values, sparsity
-
-
-@beartype.beartype
-@torch.inference_mode()
-def dump_activations(cfg: config.Topk):
- """
- For each SAE latent, we want to know which images have the most total "activation".
- That is, we keep track of each patch
- """
- if cfg.sort_by == "img":
- top_values_p, top_img_i, mean_values, sparsity = get_topk_img(cfg)
- elif cfg.sort_by == "cls":
- top_values_p, top_img_i, mean_values, sparsity = get_topk_cls(cfg)
- elif cfg.sort_by == "patch":
- top_values_p, top_img_i, mean_values, sparsity = get_topk_patch(cfg)
- else:
- typing.assert_never(cfg.sort_by)
-
- os.makedirs(cfg.root, exist_ok=True)
-
- torch.save(top_values_p, cfg.top_values_fpath)
- torch.save(top_img_i, cfg.top_img_i_fpath)
- torch.save(mean_values, cfg.mean_values_fpath)
- torch.save(sparsity, cfg.sparsity_fpath)
-
-
-@beartype.beartype
-@torch.inference_mode()
-def main(cfg: config.Topk):
- """
- .. todo:: document this function.
-
- Dump top-k images to a directory.
-
- Args:
- cfg: Configuration object.
- """
-
- try:
- top_values_p = safe_load(cfg.top_values_fpath)
- sparsity = safe_load(cfg.sparsity_fpath)
- mean_values = safe_load(cfg.mean_values_fpath)
- top_i = safe_load(cfg.top_img_i_fpath)
- except FileNotFoundError as err:
- logger.warning("Need to dump files: %s", err)
- dump_activations(cfg)
- return main(cfg)
-
- d_sae, cached_topk, n_patches = top_values_p.shape
- # Check that the data is at least shaped correctly.
- assert cfg.top_k == cached_topk
- if cfg.sort_by == "cls":
- assert n_patches == 0
- elif cfg.sort_by == "img":
- assert n_patches > 0
- elif cfg.sort_by == "patch":
- assert n_patches > 0
- else:
- typing.assert_never(cfg.sort_by)
-
- logger.info("Loaded sorted data.")
-
- dataset = activations.get_dataset(cfg.images, transform=None)
-
- min_log_freq, max_log_freq = cfg.log_freq_range
- min_log_value, max_log_value = cfg.log_value_range
- # breakpoint()
- mask = (
- (min_log_freq < torch.log10(sparsity))
- & (torch.log10(sparsity) < max_log_freq)
- & (min_log_value < torch.log10(mean_values))
- & (torch.log10(mean_values) < max_log_value)
- )
-
- neuron_i = torch.arange(d_sae)[mask.cpu()].tolist()
-
- for i in tqdm.tqdm(neuron_i, desc="saving visuals"):
- neuron_dir = os.path.join(cfg.root, "neurons", str(i))
- os.makedirs(neuron_dir, exist_ok=True)
-
- # Image grid
- elems = []
- seen_i_im = set()
- for i_im, values_p in zip(top_i[i].tolist(), top_values_p[i]):
- if i_im in seen_i_im:
- continue
- example = dataset[i_im]
- elem = GridElement(example["image"], example["label"], values_p)
- elems.append(elem)
-
- seen_i_im.add(i_im)
-
- # How to scale values.
- upper = None
- if top_values_p[i].numel() > 0:
- upper = top_values_p[i].max().item()
-
- for i, elem in enumerate(elems):
- img = make_img(elem, upper=upper)
- img.save(os.path.join(neuron_dir, f"{i}.png"))
- with open(os.path.join(neuron_dir, f"{i}.txt"), "w") as fd:
- fd.write(elem.label + "\n")
-
- # Metadata
- metadata = {
- "neuron": i,
- "log10 sparsity": torch.log10(sparsity)[i].item(),
- "mean activation": mean_values[i].item(),
- }
- with open(f"{neuron_dir}/metadata.pkl", "wb") as pickle_file:
- pickle.dump(metadata, pickle_file)
diff --git a/probing/logbook.py b/probing/logbook.py
deleted file mode 100644
index 3d2777b..0000000
--- a/probing/logbook.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import marimo
-
-__generated_with = "0.9.14"
-app = marimo.App(width="full")
-
-
-@app.cell
-def __():
- import marimo as mo
-
- return (mo,)
-
-
-@app.cell
-def __(mo):
- mo.md(
- r"""
- How can we find examples of interesting features in different checkpoints? How can we predict hypothetical trends in the way different models learn different features?
-
- In this paper, we assume that manual effort is necessary. This notebook is a method for trying to make it as easy as possible to identify interesting SAEs and predict trends between models.
- """
- )
- return
-
-
-@app.cell
-def __(mo):
- mo.md(
- r"""
- # An Observational Study of Vision Model Interpretability
-
- We have multiple ways to compare trained SAEs.
-
- 1. Heuristic measures, like number of dead features, number of dense features, mean L0 norm, etc.
- 2. Qualitative plots of feature frequency, mean activation value, L0-MSE tradeoff curves, etc.
- 3. Manual inspection of the top K images for each feature.
-
- After proposing trends, we can construct individual probing datasets (see below).
- """
- )
- return
-
-
-@app.cell
-def __(mo):
- mo.image("docs/assets/experiment1.png")
- return
-
-
-@app.cell
-def __(mo):
- mo.md(r"""`probing/notebooks/lo_mse_tradeoff.py` is a notebook to explore""")
- return
-
-
-if __name__ == "__main__":
- app.run()
diff --git a/probing/notebooks/l0_mse_tradeoff.py b/probing/notebooks/l0_mse_tradeoff.py
deleted file mode 100644
index deb26a1..0000000
--- a/probing/notebooks/l0_mse_tradeoff.py
+++ /dev/null
@@ -1,374 +0,0 @@
-import marimo
-
-__generated_with = "0.9.14"
-app = marimo.App(
- width="medium",
- css_file="/home/stevens.994/.config/marimo/custom.css",
-)
-
-
-@app.cell
-def __():
- import json
- import os
-
- import altair as alt
- import beartype
- import marimo as mo
- import matplotlib.pyplot as plt
- import numpy as np
- import polars as pl
- from jaxtyping import Float, jaxtyped
-
- import wandb
-
- return Float, alt, beartype, jaxtyped, json, mo, np, os, pl, plt, wandb
-
-
-@app.cell
-def __(mo):
- mo.md(
- r"""I want to know how points along the reconstruction-fidelity frontier vary in their sparsity-value heatmap. Then I can look at how these heatmaps differ as I change hyperparameters like normalizing \(W_\text{dec}\), etc."""
- )
- return
-
-
-@app.cell
-def __():
- tag = "reproduction-v1.0"
- return (tag,)
-
-
-@app.cell
-def __(alt, df, mo):
- chart = mo.ui.altair_chart(
- alt.Chart(
- df.select(
- "summary/eval/l0",
- "summary/losses/mse",
- "id",
- "config/sae/sparsity_coeff",
- "config/lr",
- )
- )
- .mark_point()
- .encode(
- x=alt.X("summary/eval/l0"),
- y=alt.Y("summary/losses/mse"),
- tooltip=["id"],
- shape="config/lr:N",
- color="config/sae/sparsity_coeff:N",
- )
- )
- chart
- return (chart,)
-
-
-@app.cell
-def __(chart, df, mo, np, plot_dist, plt):
- sub_df = (
- df.join(chart.value.select("id"), on="id", how="inner")
- .sort(by="summary/eval/l0")
- .select("id", "summary/eval/freqs", "summary/eval/mean_values")
- .head(4)
- )
-
- mo.stop(len(sub_df) == 0, "Select one or more points.")
-
- scatter_fig, scatter_axes = plt.subplots(
- ncols=len(sub_df), figsize=(12, 3), squeeze=False, sharey=True, sharex=True
- )
-
- hist_fig, hist_axes = plt.subplots(
- ncols=len(sub_df),
- nrows=2,
- figsize=(12, 6),
- squeeze=False,
- sharey=True,
- sharex=True,
- )
-
- # Always one row
- scatter_axes = scatter_axes.reshape(-1)
- hist_axes = hist_axes.T
-
- for (id, freqs, values), scatter_ax, (freq_hist_ax, values_hist_ax) in zip(
- sub_df.iter_rows(), scatter_axes, hist_axes
- ):
- plot_dist(
- freqs.astype(float),
- (-6.0, 0.0),
- values.astype(float),
- (-2.0, 2.0),
- scatter_ax,
- )
- # ax.scatter(freqs, values, marker=".", alpha=0.03)
- # ax.set_yscale("log")
- # ax.set_xscale("log")
- scatter_ax.set_title(id)
-
- # Plot feature
- bins = np.linspace(-6, 1, 100)
- freq_hist_ax.hist(np.log10(freqs.astype(float)), bins=bins)
- freq_hist_ax.set_title(f"{id} Feat. Freq. Dist.")
-
- values_hist_ax.hist(np.log10(values.astype(float)), bins=bins)
- values_hist_ax.set_title(f"{id} Mean Val. Distribution")
-
- scatter_fig.tight_layout()
- hist_fig.tight_layout()
- return (
- bins,
- freq_hist_ax,
- freqs,
- hist_axes,
- hist_fig,
- id,
- scatter_ax,
- scatter_axes,
- scatter_fig,
- sub_df,
- values,
- values_hist_ax,
- )
-
-
-@app.cell
-def __(scatter_fig):
- scatter_fig
- return
-
-
-@app.cell
-def __(hist_fig):
- hist_fig
- return
-
-
-@app.cell
-def __(chart, df, pl):
- df.join(chart.value.select("id"), on="id", how="inner").sort(
- by="summary/eval/l0"
- ).select("id", pl.selectors.starts_with("config/"))
- return
-
-
-@app.cell
-def __(Float, beartype, jaxtyped, np):
- @jaxtyped(typechecker=beartype.beartype)
- def plot_dist(
- freqs: Float[np.ndarray, " d_sae"],
- freqs_log_range: tuple[float, float],
- values: Float[np.ndarray, " d_sae"],
- values_log_range: tuple[float, float],
- ax,
- ):
- log_sparsity = np.log10(freqs + 1e-9)
- log_values = np.log10(values + 1e-9)
-
- mask = np.ones(len(log_sparsity)).astype(bool)
- min_log_freq, max_log_freq = freqs_log_range
- mask[log_sparsity < min_log_freq] = False
- mask[log_sparsity > max_log_freq] = False
- min_log_value, max_log_value = values_log_range
- mask[log_values < min_log_value] = False
- mask[log_values > max_log_value] = False
-
- n_shown = mask.sum()
- ax.scatter(
- log_sparsity[mask],
- log_values[mask],
- marker=".",
- alpha=0.1,
- color="tab:blue",
- label=f"Shown ({n_shown})",
- )
- n_filtered = (~mask).sum()
- ax.scatter(
- log_sparsity[~mask],
- log_values[~mask],
- marker=".",
- alpha=0.1,
- color="tab:red",
- label=f"Filtered ({n_filtered})",
- )
-
- ax.axvline(min_log_freq, linewidth=0.5, color="tab:red")
- ax.axvline(max_log_freq, linewidth=0.5, color="tab:red")
- ax.axhline(min_log_value, linewidth=0.5, color="tab:red")
- ax.axhline(max_log_value, linewidth=0.5, color="tab:red")
-
- ax.set_xlabel("Feature Frequency (log10)")
- # ax.set_ylabel("Mean Activation Value (log10)")
-
- return (plot_dist,)
-
-
-@app.cell
-def __(
- beartype,
- get_data_key,
- get_model_key,
- json,
- load_freqs,
- load_mean_values,
- os,
- pl,
- tag,
- wandb,
-):
- @beartype.beartype
- def make_df(tag: str):
- runs = wandb.Api().runs(path="samuelstevens/saev", filters={"config.tag": tag})
-
- rows = []
- for run in runs:
- row = {}
- row["id"] = run.id
-
- row.update(**{
- f"summary/{key}": value for key, value in run.summary.items()
- })
- try:
- row["summary/eval/freqs"] = load_freqs(run)
- except ValueError:
- print(f"Run {run.id} did not log eval/freqs.")
- continue
- except RuntimeError:
- print(f"Wandb blew up on run {run.id}.")
- continue
- try:
- row["summary/eval/mean_values"] = load_mean_values(run)
- except ValueError:
- print(f"Run {run.id} did not log eval/mean_values.")
- continue
- except RuntimeError:
- print(f"Wandb blew up on run {run.id}.")
- continue
-
- # config
- row.update(**{
- f"config/data/{key}": value
- for key, value in run.config.pop("data").items()
- })
- row.update(**{
- f"config/sae/{key}": value
- for key, value in run.config.pop("sae").items()
- })
-
- row.update(**{f"config/{key}": value for key, value in run.config.items()})
-
- with open(
- os.path.join(row["config/data/shard_root"], "metadata.json")
- ) as fd:
- metadata = json.load(fd)
-
- row["model_key"] = get_model_key(metadata)
-
- data_key = get_data_key(metadata)
- if data_key is None:
- print("Bad run: {run.id}")
- continue
- row["data_key"] = data_key
-
- row["config/d_vit"] = metadata["d_vit"]
- rows.append(row)
-
- if not rows:
- return None
-
- df = pl.DataFrame(rows).with_columns(
- (pl.col("config/sae/d_vit") * pl.col("config/sae/exp_factor")).alias(
- "config/sae/d_sae"
- )
- )
- return df
-
- df = make_df(tag)
- return df, make_df
-
-
-@app.cell
-def __(beartype):
- @beartype.beartype
- def get_model_key(metadata: dict[str, object]) -> str | None:
- if (
- metadata["model_org"] == "dinov2"
- and metadata["model_ckpt"] == "dinov2_vitb14_reg"
- ):
- return "DINOv2 ViT-B/14"
- if (
- metadata["model_org"] == "open-clip" or metadata["model_org"] == "clip"
- ) and metadata["model_ckpt"] == "ViT-B-16/openai":
- return "CLIP ViT-B/16"
-
- if (
- metadata["model_org"] == "clip"
- and metadata["model_ckpt"] == "ViT-L-14/openai"
- ):
- return "CLIP ViT-L/14"
-
- print(f"Unknown model: {(metadata['model_org'], metadata['model_ckpt'])}")
- return None
-
- @beartype.beartype
- def get_data_key(metadata: dict[str, object]) -> str | None:
- if "train_mini" in metadata["data"] and "Inat21Dataset" in metadata["data"]:
- return "iNat21"
-
- if "train" in metadata["data"] and "Imagenet" in metadata["data"]:
- return "ImageNet-1K"
-
- print(f"Unknown data: {metadata['data']}")
- return None
-
- return get_data_key, get_model_key
-
-
-@app.cell
-def __(Float, json, np, os):
- def load_freqs(run) -> Float[np.ndarray, " d_sae"]:
- try:
- for artifact in run.logged_artifacts():
- if "evalfreqs" not in artifact.name:
- continue
-
- dpath = artifact.download()
- fpath = os.path.join(dpath, "eval", "freqs.table.json")
- print(fpath)
- with open(fpath) as fd:
- raw = json.load(fd)
- return np.array(raw["data"]).reshape(-1)
- except Exception as err:
- raise RuntimeError("Wandb sucks.") from err
-
- raise ValueError(f"freqs not found in run '{run.id}'")
-
- def load_mean_values(run) -> Float[np.ndarray, " d_sae"]:
- try:
- for artifact in run.logged_artifacts():
- if "evalmean_values" not in artifact.name:
- continue
-
- dpath = artifact.download()
- fpath = os.path.join(dpath, "eval", "mean_values.table.json")
- print(fpath)
- with open(fpath) as fd:
- raw = json.load(fd)
- return np.array(raw["data"]).reshape(-1)
- except Exception as err:
- raise RuntimeError("Wandb sucks.") from err
-
- raise ValueError(f"mean_values not found in run '{run.id}'")
-
- return load_freqs, load_mean_values
-
-
-@app.cell
-def __(df):
- df
- return
-
-
-if __name__ == "__main__":
- app.run()
diff --git a/saev/activations.py b/saev/activations.py
index 9459a23..7ac72fc 100644
--- a/saev/activations.py
+++ b/saev/activations.py
@@ -9,6 +9,7 @@
2. Multiple [n_imgs_per_shard, n_layers, (n_patches + 1), d_vit] tensors. This is a set of sharded activations.
"""
+import collections.abc
import dataclasses
import hashlib
import json
@@ -204,7 +205,7 @@ def make_vit(cfg: config.Activations):
@beartype.beartype
-def make_img_transform(model_family: str, model_ckpt: str) -> typing.Callable:
+def make_img_transform(model_family: str, model_ckpt: str) -> collections.abc.Callable:
if model_family == "clip" or model_family == "siglip":
import open_clip
@@ -518,33 +519,33 @@ def setup_ade20k(cfg: config.Activations):
@beartype.beartype
-def get_dataset(cfg: config.DatasetConfig, *, transform):
+def get_dataset(cfg: config.DatasetConfig, *, img_transform):
"""
Gets the dataset for the current experiment; delegates construction to dataset-specific functions.
Args:
cfg: Experiment config.
- transform: Image transform to be applied to each image.
+ img_transform: Image transform to be applied to each image.
Returns:
A dataset that has dictionaries with `'image'`, `'index'`, `'target'`, and `'label'` keys containing examples.
"""
if isinstance(cfg, config.ImagenetDataset):
- return TransformedImagenet(cfg, transform=transform)
+ return Imagenet(cfg, img_transform=img_transform)
elif isinstance(cfg, config.Ade20kDataset):
- return TransformedAde20k(cfg, transform=transform)
+ return Ade20k(cfg, img_transform=img_transform)
else:
typing.assert_never(cfg)
@beartype.beartype
-def get_dataloader(cfg: config.Activations, preprocess):
+def get_dataloader(cfg: config.Activations, *, img_transform=None):
"""
Gets the dataloader for the current experiment; delegates dataloader construction to dataset-specific functions.
Args:
cfg: Experiment config.
- preprocess: Image transform to be applied to each image.
+ img_transform: Image transform to be applied to each image.
Returns:
A PyTorch Dataloader that yields dictionaries with `'image'` keys containing image batches.
@@ -553,7 +554,7 @@ def get_dataloader(cfg: config.Activations, preprocess):
cfg.data,
(config.ImagenetDataset, config.ImageFolderDataset, config.Ade20kDataset),
):
- dataloader = get_default_dataloader(cfg, preprocess)
+ dataloader = get_default_dataloader(cfg, img_transform=img_transform)
else:
typing.assert_never(cfg.data)
@@ -562,19 +563,19 @@ def get_dataloader(cfg: config.Activations, preprocess):
@beartype.beartype
def get_default_dataloader(
- cfg: config.Activations, transform
+ cfg: config.Activations, *, img_transform: collections.abc.Callable
) -> torch.utils.data.DataLoader:
"""
Get a dataloader for a default map-style dataset.
Args:
cfg: Config.
- preprocess: Image transform to be applied to each image.
+ img_transform: Image transform to be applied to each image.
Returns:
A PyTorch Dataloader that yields dictionaries with `'image'` keys containing image batches, `'index'` keys containing original dataset indices and `'label'` keys containing label batches.
"""
- dataset = get_dataset(cfg.data, transform=transform)
+ dataset = get_dataset(cfg.data, img_transform=img_transform)
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
@@ -589,15 +590,15 @@ def get_default_dataloader(
@beartype.beartype
-class TransformedImagenet(torch.utils.data.Dataset):
- def __init__(self, cfg: config.ImagenetDataset, *, transform=None):
+class Imagenet(torch.utils.data.Dataset):
+ def __init__(self, cfg: config.ImagenetDataset, *, img_transform=None):
import datasets
self.hf_dataset = datasets.load_dataset(
cfg.name, split=cfg.split, trust_remote_code=True
)
- self.transform = transform
+ self.img_transform = img_transform
self.labels = self.hf_dataset.info.features["label"].names
def __getitem__(self, i):
@@ -605,8 +606,8 @@ def __getitem__(self, i):
example["index"] = i
example["image"] = example["image"].convert("RGB")
- if self.transform:
- example["image"] = self.transform(example["image"])
+ if self.img_transform:
+ example["image"] = self.img_transform(example["image"])
example["target"] = example.pop("label")
example["label"] = self.labels[example["target"]]
@@ -617,7 +618,7 @@ def __len__(self) -> int:
@beartype.beartype
-class TransformedImageFolder(torchvision.datasets.ImageFolder):
+class ImageFolder(torchvision.datasets.ImageFolder):
def __getitem__(self, index: int) -> dict[str, object]:
"""
Args:
@@ -629,8 +630,8 @@ def __getitem__(self, index: int) -> dict[str, object]:
breakpoint()
path, target = self.samples[index]
sample = self.loader(path)
- if self.transform is not None:
- sample = self.transform(sample)
+ if self.img_transform is not None:
+ sample = self.img_transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
@@ -638,24 +639,29 @@ def __getitem__(self, index: int) -> dict[str, object]:
@beartype.beartype
-class TransformedAde20k(torch.utils.data.Dataset):
- class Sample(typing.TypedDict):
+class Ade20k(torch.utils.data.Dataset):
+ @beartype.beartype
+ @dataclasses.dataclass(frozen=True)
+ class Sample:
img_path: str
seg_path: str
- split: str
label: str
target: int
samples: list[Sample]
def __init__(
- self, cfg: config.Ade20kDataset, *, transform=None, seg_transform=None
+ self,
+ cfg: config.Ade20kDataset,
+ *,
+ img_transform: collections.abc.Callable | None = None,
+ seg_transform: collections.abc.Callable | None = lambda x: None,
):
self.logger = logging.getLogger("ade20k")
self.cfg = cfg
self.img_dir = os.path.join(cfg.root, "images")
self.seg_dir = os.path.join(cfg.root, "annotations")
- self.transform = transform
+ self.img_transform = img_transform
self.seg_transform = seg_transform
# Check that we have the right path.
@@ -674,18 +680,28 @@ def __init__(
}
self.loader = torchvision.datasets.folder.default_loader
+ assert cfg.split in set(split_lookup.values())
+
# Load all the image paths.
- imgs: list[tuple[str, int]] = torchvision.datasets.folder.make_dataset(
- self.img_dir,
- split_mapping,
- extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
- )
+ imgs: list[str] = [
+ path
+ for path, s in torchvision.datasets.folder.make_dataset(
+ self.img_dir,
+ split_mapping,
+ extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
+ )
+ if split_lookup[s] == cfg.split
+ ]
- segs: list[tuple[str, int]] = torchvision.datasets.folder.make_dataset(
- self.seg_dir,
- split_mapping,
- extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
- )
+ segs: list[str] = [
+ path
+ for path, s in torchvision.datasets.folder.make_dataset(
+ self.seg_dir,
+ split_mapping,
+ extensions=torchvision.datasets.folder.IMG_EXTENSIONS,
+ )
+ if split_lookup[s] == cfg.split
+ ]
# Load all the targets, classes and mappings
with open(os.path.join(cfg.root, "sceneCategories.txt")) as fd:
@@ -695,27 +711,25 @@ def __init__(
label_to_idx = {label: i for i, label in enumerate(label_set)}
self.samples = [
- self.Sample(
- img_path=img_path,
- seg_path=seg_path,
- split=split_lookup[split],
- label=label,
- target=label_to_idx[label],
- )
- for (img_path, split), (seg_path, _), label in zip(imgs, segs, img_labels)
+ self.Sample(img_path, seg_path, label, label_to_idx[label])
+ for img_path, seg_path, label in zip(imgs, segs, img_labels)
]
def __getitem__(self, index: int) -> dict[str, object]:
- # Make a copy
- sample = dict(**self.samples[index])
+ # Convert to dict.
+ sample = dataclasses.asdict(self.samples[index])
sample["image"] = self.loader(sample.pop("img_path"))
- if self.transform is not None:
- sample["image"] = self.transform(sample["image"])
+ if self.img_transform is not None:
+ image = self.img_transform(sample.pop("image"))
+ if image is not None:
+ sample["image"] = image
sample["segmentation"] = Image.open(sample.pop("seg_path")).convert("L")
if self.seg_transform is not None:
- sample["segmentation"] = self.seg_transform(sample["segmentation"])
+ segmentation = self.seg_transform(sample.pop("segmentation"))
+ if segmentation is not None:
+ sample["segmentation"] = segmentation
sample["index"] = index
@@ -790,7 +804,7 @@ def worker_fn(cfg: config.Activations):
vit = make_vit(cfg)
img_transform = make_img_transform(cfg.model_family, cfg.model_ckpt)
- dataloader = get_dataloader(cfg, img_transform)
+ dataloader = get_dataloader(cfg, img_transform=img_transform)
writer = ShardWriter(cfg)
diff --git a/saev/config.py b/saev/config.py
index 649932b..2652bfd 100644
--- a/saev/config.py
+++ b/saev/config.py
@@ -65,11 +65,15 @@ class Ade20kDataset:
root: str = os.path.join(".", "data", "split")
"""Where the class folders with images are stored."""
+ split: typing.Literal["training", "validation"] = "training"
+ """Data split."""
@property
def n_imgs(self) -> int:
- with open(os.path.join(self.root, "sceneCategories.txt")) as fd:
- return len(fd.read().split("\n"))
+ if self.split == "validation":
+ return 2000
+ else:
+ return 20210
DatasetConfig = ImagenetDataset | ImageFolderDataset | Ade20kDataset