Skip to content

Commit

Permalink
Trying to reproduce original ViT-B/16 with ViT-L/14
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelstevens committed Nov 20, 2024
1 parent 9d4cae4 commit 03067f9
Show file tree
Hide file tree
Showing 13 changed files with 296 additions and 313 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ wandb/
artifacts/
checkpoints/
logs/
web/
.hypothesis/
13 changes: 6 additions & 7 deletions configs/preprint/baseline.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
lr = 0.0004
lr = [1e-4, 3e-4, 1e-3]

n_lr_warmup = 500
n_sparsity_warmup = 0
n_sparsity_warmup = 500

tag = "baseline-v3.0"
tag = "baseline-v4.1"

[sae]
sparsity_coeff = 0.00008
sparsity_coeff = [4e-5, 8e-5, 1.6e-4]
ghost_grads = false
normalize_w_dec = [false, true]
remove_parallel_grads = [false, true]
normalize_w_dec = true
remove_parallel_grads = true
exp_factor = 16
n_reinit_samples = [0, 524_288]
6 changes: 3 additions & 3 deletions justfile
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
docs: lint
uv run pdoc3 --force --html --output-dir docs --config latex_math=True saev
uv run pdoc3 --force --html --output-dir docs --config latex_math=True saev probing

test: lint
uv run pytest saev
uv run pytest saev probing

lint: fmt
ruff check saev/ main.py webapp.py
fd -e py | xargs ruff check

fmt:
fd -e py | xargs isort
Expand Down
26 changes: 13 additions & 13 deletions llms.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,26 @@ I'm writing a submission to ICML. The premise is that we apply sparse autoencode
My current outline is

1. Introduction
1. We want to interpret foundation vision models
2. We want to see examples of the concepts being represented.
3. SAEs are the best way to do this.
4. We apply SAEs to DINOv2 and CLIP vision models and find some neat stuff.
1.1. Vision foundation models are crucial for modern computer vision but require interpretability for high-stakes scenarios.
1.2. Foundation models learn rich internal representations, supported by theoretical results (information bottleneck, nuisance variables) and empirical evidence from interpretability methods.
1.3. Understanding learned concepts is challenging due to limitations of current methods (individual neurons vs distributed representations) and superposition of features.
1.4. An ideal interpretability method should: (a) work on foundation models, (b) explain concepts being represented, and (c) be computationally efficient.
1.5. Current methods (saliency maps, ProtoPNet) fail to meet these requirements due to various limitations.
1.6. Sparse autoencoders (SAEs) satisfy all requirements for an ideal interpretability method.
1.7. We analyze CLIP and DINOv2 using SAEs, revealing how different training objectives lead to systematic differences in learned representations.
2. Related work.
3. How we train the SAEs (technical details, easy to write)
3.1. We train ReLU autoencoders, minimizing both MSE reconstruction and L1 sparsity of the intermediate activations, parameterized by lambda.
3.2. We train on 100M patch-level activations from the second to last layer of a model M on a dataset D, typically DINOv2 on ImageNet-1K.
4. Findings
1. CLIP learns abstract semantic relationships, DINOv2 doesn't.
2. DINOv2 identifies morphological traits in animals much more often than CLIP.
3. Training an SAE on one datset transfers to a new dataset.
4.1. Different pre-training supervision methods learn different features: we compare vision-language models like CLIP and SigLIP to vision-only models like DINOv2 and V-JEPA. Language-vision models learn abstract semantic relationships, vision-only don't.
4.2. (unknown section title) If we know the learned features of vision models, and the features we expect to be relevant for a new downstream task, we can predict which vision models will do better on a downstream task.
4.3. (unknown title) SAE-recovered features can effectively steer classifications on downstream tasks.
5. Conclusion & Future Work

---

I also like the following setup for the paper:


Vision foundation models (like DINOv2 and CLIP) are widely used as feature extractors across diverse tasks. However, **their success hinges on their internal representations capturing the right concepts for downstream decision-making.**

With respect to writing, we want to frame everything as Goal->Problem->Solution. In general, I want you to be skeptical and challenging of arguments that are not supported by evidence.
With respect to writing, we want to frame everything as Goal->Problem->Solution. In general, I want you to be extremely skeptical and challenge any arguments that are not supported by evidence.

Some questions that come up that are not in the outline yet:

Expand Down
127 changes: 114 additions & 13 deletions logbook.md
Original file line number Diff line number Diff line change
Expand Up @@ -707,21 +707,122 @@ What patterns am I seeing now?



# 11/14/2024

I'm writing a submission to CVPR. The premise is that we apply sparse autoencoders to vision models to interpret their internal representations.
Anthropic discusses ways to search for particular features [here](https://transformer-circuits.pub/2024/scaling-monosemanticity/index.html#searching).
In general, it seems using combinations of positive/negative examples and filtering by "fire/no fire" is good, rather than the LDA-based classifier I have now.

My current outline is
# 11/18/2024

1. Introduction: we want to interpret foundation vision models and see examples of the concepts being represented. SAEs are the best way to do this. We apply SAEs to DINOv2 and CLIP vision models and find some neat stuff.
2. Related work.
3. How we train the SAEs (technical details, easy to write)
4. Findings
1. CLIP learns abstract semantic relationships, DINOv2 doesn't.
2. DINOv2 identifies morphological traits in animals much more often than CLIP.
3. Training an SAE on one datset transfers to a new dataset.
5. Conclusion & Future Work
### Influential Topics and Papers in Model Debugging

# 11/14/2024
Here’s a breakdown of key topics and landmark papers to guide your search:

Anthropic discusses ways to search for particular features [here](https://transformer-circuits.pub/2024/scaling-monosemanticity/index.html#searching).
In general, it seems using combinations of positive/negative examples and filtering by "fire/no fire" is good, rather than the LDA-based classifier I have now.
---

### 1. **Interpretability and Conceptual Debugging**
- **Papers**:
- *“Distill and Detect: Mapping Knowledge Representations for Neural Networks”* by Hinton et al. Introduced distillation, with implications for debugging via knowledge transfer.
- *“Network Dissection: Quantifying Interpretability of Deep Visual Representations”* by Bau et al. (2017). Introduced a method for attributing neurons to human-understandable concepts.
- *“Testing with Concept Activation Vectors (TCAV)”* by Kim et al. (2018). Uses concept vectors for debugging biases and assessing representation relevance.

- **Search Terms**:
- "Concept-based interpretability"
- "Neuron-level interpretability"
- "Network dissection techniques"

---

### 2. **Feature Attribution**
- **Papers**:
- *“Integrated Gradients”* by Sundararajan et al. (2017). A widely-used method for explaining model predictions via input gradients.
- *“SHAP: A Unified Approach to Interpreting Model Predictions”* by Lundberg and Lee (2017). Provides local explanations for individual predictions.
- *“Grad-CAM”* by Selvaraju et al. (2017). Explains model decisions by visualizing class-specific activation maps.

- **Search Terms**:
- "Feature attribution methods"
- "Saliency maps"
- "Explainability in deep learning"

---

### 3. **Debugging via Counterfactuals**
- **Papers**:
- *“Counterfactual Explanations without Opening the Black Box”* by Wachter et al. (2017). Explores generating counterfactual examples for model debugging.
- *“Robustness Disparities by Design?”* by D’Amour et al. (2020). Discusses biases exposed via adversarial counterfactuals.

- **Search Terms**:
- "Counterfactual model debugging"
- "Robustness and adversarial testing"

---

### 4. **Bias and Fairness Debugging**
- **Papers**:
- *“A Framework for Understanding Unintended Consequences of Machine Learning”* by Suresh and Guttag (2019). Highlights sources of bias and debugging strategies.
- *“The Mythos of Model Interpretability”* by Lipton (2016). Discusses the trade-offs and pitfalls in debugging interpretability.

- **Search Terms**:
- "Bias debugging in ML models"
- "Fairness-aware machine learning"

---

### 5. **Debugging Neural Representations**
- **Papers**:
- *“Probing Neural Representations”* by Conneau et al. (2018). Uses diagnostic classifiers to analyze internal representations.
- *“Representation Erasure”* by Elazar and Goldberg (2018). Tests the influence of specific features by removing them and observing performance.

- **Search Terms**:
- "Probing neural representations"
- "Feature erasure methods"

---

### 6. **Model Behavior and Failure Analysis**
- **Papers**:
- *“A Taxonomy of Machine Learning Failures”* by Barredo Arrieta et al. (2020). Categorizes failure modes and discusses debugging strategies.
- *“Debugging Deep Models with Logical Constraints”* by Narayanan et al. (2018). Formal methods for pinpointing and fixing logic violations in neural network outputs.

- **Search Terms**:
- "Failure modes in ML"
- "Behavioral analysis of neural models"

---

### 7. **Transfer Learning and Representation Debugging**
- **Papers**:
- *“Do Better ImageNet Models Transfer Better?”* by Kornblith et al. (2019). Studies how representation quality affects transfer learning.
- *“Representation Learning with Contrastive Predictive Coding”* by Oord et al. (2018). Debugging representation learning via contrastive methods.

- **Search Terms**:
- "Transfer learning debugging"
- "Contrastive representation learning"

---

### Practical Suggestions
Use combinations of terms like *“debugging deep models,”* *“interpretability techniques,”* and *“bias evaluation in neural networks”* to drill down further. Sites like [Papers with Code](https://paperswithcode.com/) and ArXiv categories (e.g., cs.LG, cs.CV) can quickly lead you to relevant papers.

Would you like to focus on any particular subdomain or technique?


https://huggingface.co/docs/transformers/main/en/tasks/semantic_segmentation
https://github.com/facebookresearch/dinov2/issues/25


# 11/19/2024

For some reason when I trained on ViT-L/14 activations nothing worked.
So we need to debug that.
There are several changes

* ViT-B/16 -> ViT-L/14
* Removed b_dec re-init
* Scaled mean activation norm to approximately sqrt(d_vit)
* Subtracted approximate mean activation to center activations
* Various code changes :(

So I will re-record ViT-B/16 activations, then re-add the b_dec re-init, ignore the scale_norm and scale_mean and *pray* that it works again.
Then I will re-add the original changes, debugging what went wrong at each step.
It might just be a learning rate thing.
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def run_imagenet1k():

jobs = []
# jobs.append(executor.submit(run_histograms))
jobs.append(executor.submit(run_broden))
# jobs.append(executor.submit(run_imagenet1k))
# jobs.append(executor.submit(run_broden))
jobs.append(executor.submit(run_imagenet1k))
for job in jobs:
job.result()

Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ description = "Sparse autoencoders for vision transformers in PyTorch"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"altair>=5.4.1",
"beartype>=0.19.0",
"datasets>=3.0.1",
"einops>=0.8.0",
"img2dataset",
"jaxtyping>=0.2.34",
"jupyterlab>=4.3.0",
"marimo>=0.9.10",
"matplotlib>=3.9.2",
"open-clip-torch>=2.28.0",
"pillow>=11.0.0",
"polars>=1.12.0",
"submitit>=1.5.2",
"torch>=2.5.0",
"tqdm>=4.66.5",
Expand All @@ -27,13 +29,12 @@ ignore = ["F722"]

[tool.uv]
dev-dependencies = [
"hypothesis>=6.119.0",
"hypothesis-torch>=0.8.4",
"pdoc3>=0.11.1",
"pytest>=8.3.3",
]

[tool.uv.sources]
img2dataset = { path = "../img2dataset", editable = true }

[tool.pytest.ini_options]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
Expand Down
Loading

0 comments on commit 03067f9

Please sign in to comment.