Add tests against Neel's Anthropic paper comment implementation (#122)
1 parent 130ee59 · commit 0bd0c9d
Showing 4 changed files with 349 additions and 0 deletions.
@@ -12,6 +12,7 @@
   "capturable",
   "categoricalwprobabilities",
   "circuitsvis",
+  "coeff",
   "colab",
   "cuda",
   "cudnn",
sparse_autoencoder/autoencoder/components/tests/test_compare_neel_implementation.py (203 additions, 0 deletions)
@@ -0,0 +1,203 @@
"""Compare the SAE implementation to Neel's 1L Implementation. | ||
https://github.com/neelnanda-io/1L-Sparse-Autoencoder/blob/main/utils.py | ||
""" | ||
import torch | ||
from torch import nn | ||
|
||
from sparse_autoencoder.autoencoder.model import SparseAutoencoder | ||
|
||
|
||
class NeelAutoencoder(nn.Module): | ||
"""Neel's 1L autoencoder implementation.""" | ||
|
||
def __init__( | ||
self, | ||
d_hidden: int, | ||
act_size: int, | ||
l1_coeff: float, | ||
dtype: torch.dtype = torch.float32, | ||
) -> None: | ||
"""Initialize the autoencoder.""" | ||
super().__init__() | ||
self.b_dec = nn.Parameter(torch.zeros(act_size, dtype=dtype)) | ||
self.W_enc = nn.Parameter( | ||
torch.nn.init.kaiming_uniform_(torch.empty(act_size, d_hidden, dtype=dtype)) | ||
) | ||
self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype)) | ||
self.W_dec = nn.Parameter( | ||
torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, act_size, dtype=dtype)) | ||
) | ||
|
||
self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) | ||
|
||
self.d_hidden = d_hidden | ||
self.l1_coeff = l1_coeff | ||
|
||
def forward( | ||
self, x: torch.Tensor | ||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||
"""Forward pass.""" | ||
x_cent = x - self.b_dec | ||
acts = nn.functional.relu(x_cent @ self.W_enc + self.b_enc) | ||
x_reconstruct = acts @ self.W_dec + self.b_dec | ||
l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0) | ||
l1_loss = self.l1_coeff * (acts.float().abs().sum()) | ||
loss = l2_loss + l1_loss | ||
return loss, x_reconstruct, acts, l2_loss, l1_loss | ||
|
||
def make_decoder_weights_and_grad_unit_norm(self) -> None: | ||
"""Make decoder weights and gradient unit norm.""" | ||
with torch.no_grad(): | ||
weight_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) | ||
weight_dec_grad_proj = (self.W_dec.grad * weight_dec_normed).sum( | ||
-1, keepdim=True | ||
) * weight_dec_normed | ||
self.W_dec.grad -= weight_dec_grad_proj | ||
# Bugfix(?) | ||
self.W_dec.data = weight_dec_normed | ||
|
||
|
||
def test_biases_initialised_same_way() -> None: | ||
"""Test that the biases are initialised the same.""" | ||
n_input_features: int = 2 | ||
n_learned_features: int = 3 | ||
l1_coefficient: float = 0.01 | ||
|
||
torch.random.manual_seed(0) | ||
autoencoder = SparseAutoencoder( | ||
n_input_features=n_input_features, | ||
n_learned_features=n_learned_features, | ||
) | ||
|
||
torch.random.manual_seed(0) | ||
neel_autoencoder = NeelAutoencoder( | ||
d_hidden=n_learned_features, | ||
act_size=n_input_features, | ||
l1_coeff=l1_coefficient, | ||
) | ||
|
||
assert torch.allclose(autoencoder.tied_bias, neel_autoencoder.b_dec) | ||
    # Note we can't compare the weights, as Neel's implementation uses rotated tensors and
    # applies Kaiming initialisation incorrectly (it uses the leaky ReLU variant and the wrong
    # fan strategy for the rotation used). Note also that the encoder bias is initialised to
    # zero in Neel's implementation, whereas we use the standard PyTorch initialisation.


def test_forward_pass_same_weights() -> None:
    """Test a forward pass with the same weights."""
    n_input_features: int = 12
    n_learned_features: int = 48
    l1_coefficient: float = 0.01

    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )

    # Set the same weights
    autoencoder.encoder.weight.data = neel_autoencoder.W_enc.data.T
    autoencoder.decoder.weight.data = neel_autoencoder.W_dec.data.T
    autoencoder.tied_bias.data = neel_autoencoder.b_dec.data
    autoencoder.encoder.bias.data = neel_autoencoder.b_enc.data

    # Create some test data
    test_batch = torch.randn(4, n_input_features)
    learned, hidden = autoencoder.forward(test_batch)
    _loss, x_reconstruct, acts, _l2_loss, _l1_loss = neel_autoencoder.forward(test_batch)

    assert torch.allclose(learned, acts)
    assert torch.allclose(hidden, x_reconstruct)


def test_unit_norm_weights() -> None:
    """Test that the decoder weights are unit normalized in the same way."""
    n_input_features: int = 2
    n_learned_features: int = 4
    l1_coefficient: float = 0.01

    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )
    pre_unit_norm_weights = autoencoder.decoder.weight.clone()
    pre_unit_norm_neel_weights = neel_autoencoder.W_dec.clone()

    # Set the same decoder weights
    decoder_weights = torch.rand_like(autoencoder.decoder.weight)
    autoencoder.decoder._weight.data = decoder_weights  # noqa: SLF001 # type: ignore
    neel_autoencoder.W_dec.data = decoder_weights.T

    # Do a forward & backward pass so we have gradients
    test_batch = torch.randn(4, n_input_features)
    _learned, decoded = autoencoder.forward(test_batch)
    decoded.sum().backward()
    decoded = neel_autoencoder.forward(test_batch)[1]
    decoded.sum().backward()

    # Apply the unit norm
    autoencoder.decoder.constrain_weights_unit_norm()
    neel_autoencoder.make_decoder_weights_and_grad_unit_norm()

    # Check the decoder weights are the same with both models
    assert torch.allclose(autoencoder.decoder.weight, neel_autoencoder.W_dec.T)

    # Check the trivial case that the weights haven't just stayed the same as before the unit norm
    assert not torch.allclose(autoencoder.decoder.weight, pre_unit_norm_weights)
    assert not torch.allclose(neel_autoencoder.W_dec, pre_unit_norm_neel_weights)


def test_unit_norm_weights_grad() -> None:
"""Test that the decoder weights are unit normalized in the same way.""" | ||
    torch.random.manual_seed(42)
    n_input_features: int = 2
    n_learned_features: int = 4
    l1_coefficient: float = 0.01

    autoencoder = SparseAutoencoder(
        n_input_features=n_input_features,
        n_learned_features=n_learned_features,
    )
    neel_autoencoder = NeelAutoencoder(
        d_hidden=n_learned_features,
        act_size=n_input_features,
        l1_coeff=l1_coefficient,
    )

    # Set the same decoder weights
    decoder_weights = torch.rand_like(autoencoder.decoder.weight)
    autoencoder.decoder._weight.data = decoder_weights  # noqa: SLF001 # type: ignore
    neel_autoencoder.W_dec.data = decoder_weights.T
    autoencoder.decoder._weight.grad = torch.zeros_like(autoencoder.decoder.weight)  # noqa: SLF001 # type: ignore
    neel_autoencoder.W_dec.grad = torch.zeros_like(neel_autoencoder.W_dec)

    # Set the same tied bias weights
    neel_autoencoder.b_dec.data = autoencoder.tied_bias.data
    neel_autoencoder.b_enc.data = autoencoder.encoder.bias.data
    neel_autoencoder.W_enc.data = autoencoder.encoder.weight.data.T

    # Do a forward & backward pass so we have gradients
    test_batch = torch.randn(4, n_input_features)
    _learned, decoded = autoencoder.forward(test_batch)
    decoded.sum().backward()
    neel_decoded = neel_autoencoder.forward(test_batch)[1]
    neel_decoded.sum().backward()

    # Apply the unit norm
    autoencoder.decoder.constrain_weights_unit_norm()
    neel_autoencoder.make_decoder_weights_and_grad_unit_norm()

    # Check the gradient weights are the same
    assert autoencoder.decoder.weight.grad is not None
    assert neel_autoencoder.W_dec.grad is not None
    assert torch.allclose(autoencoder.decoder.weight.grad, neel_autoencoder.W_dec.grad.T, rtol=1e-4)
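The projection in make_decoder_weights_and_grad_unit_norm removes the component of each row's gradient that lies along that (unit-normed) row, so an optimizer step moves the decoder weights tangentially rather than changing their norm. A minimal standalone sketch of that property, using plain tensors and hypothetical shapes rather than either autoencoder class:

import torch

# Hypothetical shapes for illustration only: 4 learned features, 2 input features.
W_dec = torch.rand(4, 2)
grad = torch.rand(4, 2)

# The same projection as make_decoder_weights_and_grad_unit_norm above.
W_normed = W_dec / W_dec.norm(dim=-1, keepdim=True)
grad_proj = (grad * W_normed).sum(-1, keepdim=True) * W_normed
grad_tangent = grad - grad_proj

# Each row's remaining gradient is orthogonal to that row's unit-normed weights...
assert torch.allclose((grad_tangent * W_normed).sum(-1), torch.zeros(4), atol=1e-6)
# ...and every decoder row now has unit norm.
assert torch.allclose(W_normed.norm(dim=-1), torch.ones(4), atol=1e-6)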
@@ -0,0 +1,144 @@
"""Tests against Neel's Autoencoder Loss. | ||
Compare module output against Neel's implementation at | ||
https://github.com/neelnanda-io/1L-Sparse-Autoencoder/blob/main/utils.py . | ||
""" | ||
from typing import TypedDict | ||
|
||
import pytest | ||
import torch | ||
|
||
from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss | ||
from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss | ||
from sparse_autoencoder.loss.reducer import LossReducer | ||
from sparse_autoencoder.tensor_types import ( | ||
InputOutputActivationBatch, | ||
ItemTensor, | ||
LearnedActivationBatch, | ||
) | ||
|
||
|
||
def neel_loss( | ||
source_activations: InputOutputActivationBatch, | ||
learned_activations: LearnedActivationBatch, | ||
decoded_activations: InputOutputActivationBatch, | ||
l1_coefficient: float, | ||
) -> tuple[ItemTensor, ItemTensor, ItemTensor]: | ||
"""Neel's loss function.""" | ||
l2_loss = (decoded_activations.float() - source_activations.float()).pow(2).sum(-1).mean(0) | ||
l1_loss = l1_coefficient * (learned_activations.float().abs().sum()) | ||
loss = l2_loss + l1_loss | ||
return l1_loss, l2_loss, loss | ||
|
||
|
||
def lib_loss( | ||
source_activations: InputOutputActivationBatch, | ||
learned_activations: LearnedActivationBatch, | ||
decoded_activations: InputOutputActivationBatch, | ||
l1_coefficient: float, | ||
) -> tuple[ItemTensor, ItemTensor, ItemTensor]: | ||
"""This library's loss function.""" | ||
l1_loss_fn = LearnedActivationsL1Loss( | ||
l1_coefficient=float(l1_coefficient), | ||
) | ||
l2_loss_fn = L2ReconstructionLoss() | ||
|
||
loss_fn = LossReducer(l1_loss_fn, l2_loss_fn) | ||
|
||
l1_loss = l1_loss_fn.forward(source_activations, learned_activations, decoded_activations) | ||
l2_loss = l2_loss_fn.forward(source_activations, learned_activations, decoded_activations) | ||
total_loss = loss_fn.forward(source_activations, learned_activations, decoded_activations) | ||
return l1_loss.sum(), l2_loss.sum(), total_loss.sum() | ||
|
||
|
||
class MockActivations(TypedDict): | ||
"""Mock activations.""" | ||
|
||
source_activations: InputOutputActivationBatch | ||
learned_activations: LearnedActivationBatch | ||
decoded_activations: InputOutputActivationBatch | ||
|
||
|
||
@pytest.fixture() | ||
def mock_activations() -> MockActivations: | ||
"""Create mock activations. | ||
Returns: | ||
Tuple of source activations, learned activations, and decoded activations. | ||
""" | ||
    source_activations = torch.rand(10, 20)
    learned_activations = torch.rand(10, 50)
    decoded_activations = torch.rand(10, 20)
    return {
        "source_activations": source_activations,
        "learned_activations": learned_activations,
        "decoded_activations": decoded_activations,
    }


def test_l1_loss_the_same(mock_activations: MockActivations) -> None:
    """Test that the L1 loss is the same."""
    l1_coefficient: float = 0.01

    neel_l1_loss = neel_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[0]

    lib_l1_loss = lib_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[0].sum()

    assert torch.allclose(neel_l1_loss, lib_l1_loss)


def test_l2_loss_the_same(mock_activations: MockActivations) -> None:
    """Test that the L2 loss is the same."""
    l1_coefficient: float = 0.01

    neel_l2_loss = neel_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[1]

    lib_l2_loss = lib_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[1].sum()

    # Correct for the fact that Neel's L2 loss is summed across the features dimension and then
    # averaged across the batch, whereas his L1 loss is summed across both the features and
    # batch dimensions.
    neel_l2_loss_fixed = neel_l2_loss * len(mock_activations["source_activations"])

    assert torch.allclose(neel_l2_loss_fixed, lib_l2_loss)


@pytest.mark.skip("We believe Neel's L2 approach is different to the original paper.")
def test_total_loss_the_same(mock_activations: MockActivations) -> None:
    """Test that the total loss is the same."""
    l1_coefficient: float = 0.01

    neel_total_loss = neel_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[2].sum()

    lib_total_loss = lib_loss(
        source_activations=mock_activations["source_activations"],
        learned_activations=mock_activations["learned_activations"],
        decoded_activations=mock_activations["decoded_activations"],
        l1_coefficient=l1_coefficient,
    )[2].sum()

    assert torch.allclose(neel_total_loss, lib_total_loss)
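For reference, the correction in test_l2_loss_the_same is just a rescaling between a batch mean and a batch sum: Neel's L2 term sums the squared error over the feature dimension and then averages over the batch, so multiplying it by the batch size recovers the fully summed squared error that the batch-summed library loss is compared against. A short self-contained sketch of that identity, with illustrative shapes only:

import torch

# Illustrative shapes only: a batch of 10 examples with 20 features each.
batch_size, n_features = 10, 20
source = torch.rand(batch_size, n_features)
decoded = torch.rand(batch_size, n_features)

squared_error = (decoded - source).pow(2)

# Neel's L2: sum over the feature dimension, then mean over the batch.
neel_l2 = squared_error.sum(-1).mean(0)

# Total squared error, summed over both the feature and batch dimensions.
summed_l2 = squared_error.sum()

# Multiplying the batch-averaged loss by the batch size recovers the batch sum,
# which is the correction applied in test_l2_loss_the_same above.
assert torch.allclose(neel_l2 * batch_size, summed_l2)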