Add tests against Neel's Anthropic paper comment implementation (#122)
alan-cooney authored Nov 29, 2023
1 parent 130ee59 commit 0bd0c9d
Showing 4 changed files with 349 additions and 0 deletions.
1 change: 1 addition & 0 deletions .vscode/cspell.json
@@ -12,6 +12,7 @@
"capturable",
"categoricalwprobabilities",
"circuitsvis",
"coeff",
"colab",
"cuda",
"cudnn",
1 change: 1 addition & 0 deletions pyproject.toml
@@ -122,6 +122,7 @@
addopts=[
"--doctest-modules",
"--jaxtyping-packages=sparse_autoencoder,beartype.beartype",
"-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning",
"-s",
]

@@ -0,0 +1,203 @@
"""Compare the SAE implementation to Neel's 1L Implementation.
https://github.com/neelnanda-io/1L-Sparse-Autoencoder/blob/main/utils.py
"""
import torch
from torch import nn

from sparse_autoencoder.autoencoder.model import SparseAutoencoder


class NeelAutoencoder(nn.Module):
"""Neel's 1L autoencoder implementation."""

def __init__(
self,
d_hidden: int,
act_size: int,
l1_coeff: float,
dtype: torch.dtype = torch.float32,
) -> None:
"""Initialize the autoencoder."""
super().__init__()
self.b_dec = nn.Parameter(torch.zeros(act_size, dtype=dtype))
self.W_enc = nn.Parameter(
torch.nn.init.kaiming_uniform_(torch.empty(act_size, d_hidden, dtype=dtype))
)
self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
self.W_dec = nn.Parameter(
torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, act_size, dtype=dtype))
)
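        # The decoder rows are normalised to unit norm below; the same constraint is
        # re-applied during training by make_decoder_weights_and_grad_unit_norm.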

self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

self.d_hidden = d_hidden
self.l1_coeff = l1_coeff

def forward(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward pass."""
x_cent = x - self.b_dec
acts = nn.functional.relu(x_cent @ self.W_enc + self.b_enc)
x_reconstruct = acts @ self.W_dec + self.b_dec
l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
l1_loss = self.l1_coeff * (acts.float().abs().sum())
loss = l2_loss + l1_loss
return loss, x_reconstruct, acts, l2_loss, l1_loss

def make_decoder_weights_and_grad_unit_norm(self) -> None:
"""Make decoder weights and gradient unit norm."""
with torch.no_grad():
weight_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
weight_dec_grad_proj = (self.W_dec.grad * weight_dec_normed).sum(
-1, keepdim=True
) * weight_dec_normed
self.W_dec.grad -= weight_dec_grad_proj
            # Bugfix(?): also write the re-normalised weights back, not just the adjusted gradient.
self.W_dec.data = weight_dec_normed


def test_biases_initialised_same_way() -> None:
"""Test that the biases are initialised the same."""
n_input_features: int = 2
n_learned_features: int = 3
l1_coefficient: float = 0.01

torch.random.manual_seed(0)
autoencoder = SparseAutoencoder(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
)

torch.random.manual_seed(0)
neel_autoencoder = NeelAutoencoder(
d_hidden=n_learned_features,
act_size=n_input_features,
l1_coeff=l1_coefficient,
)

assert torch.allclose(autoencoder.tied_bias, neel_autoencoder.b_dec)
    # Note that we can't compare the weights directly: Neel's implementation stores the
    # tensors transposed and applies Kaiming initialisation with the leaky-ReLU gain and a
    # fan mode that doesn't match that transposed layout. Note also that the encoder bias
    # is initialised to zero in Neel's implementation, whereas we use the standard PyTorch
    # initialisation.


def test_forward_pass_same_weights() -> None:
"""Test a forward pass with the same weights."""
n_input_features: int = 12
n_learned_features: int = 48
l1_coefficient: float = 0.01

autoencoder = SparseAutoencoder(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
)
neel_autoencoder = NeelAutoencoder(
d_hidden=n_learned_features,
act_size=n_input_features,
l1_coeff=l1_coefficient,
)

# Set the same weights
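    # (The library follows the nn.Linear convention of storing weights as
    # (out_features, in_features), whereas Neel's tensors are stored transposed,
    # hence the .T below.)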
autoencoder.encoder.weight.data = neel_autoencoder.W_enc.data.T
autoencoder.decoder.weight.data = neel_autoencoder.W_dec.data.T
autoencoder.tied_bias.data = neel_autoencoder.b_dec.data
autoencoder.encoder.bias.data = neel_autoencoder.b_enc.data

# Create some test data
test_batch = torch.randn(4, n_input_features)
learned, hidden = autoencoder.forward(test_batch)
_loss, x_reconstruct, acts, _l2_loss, _l1_loss = neel_autoencoder.forward(test_batch)

assert torch.allclose(learned, acts)
assert torch.allclose(hidden, x_reconstruct)


def test_unit_norm_weights() -> None:
"""Test that the decoder weights are unit normalized in the same way."""
n_input_features: int = 2
n_learned_features: int = 4
l1_coefficient: float = 0.01

autoencoder = SparseAutoencoder(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
)
neel_autoencoder = NeelAutoencoder(
d_hidden=n_learned_features,
act_size=n_input_features,
l1_coeff=l1_coefficient,
)
pre_unit_norm_weights = autoencoder.decoder.weight.clone()
pre_unit_norm_neel_weights = neel_autoencoder.W_dec.clone()

# Set the same decoder weights
decoder_weights = torch.rand_like(autoencoder.decoder.weight)
autoencoder.decoder._weight.data = decoder_weights # noqa: SLF001 # type: ignore
neel_autoencoder.W_dec.data = decoder_weights.T

# Do a forward & backward pass so we have gradients
test_batch = torch.randn(4, n_input_features)
_learned, decoded = autoencoder.forward(test_batch)
decoded.sum().backward()
decoded = neel_autoencoder.forward(test_batch)[1]
decoded.sum().backward()

# Apply the unit norm
autoencoder.decoder.constrain_weights_unit_norm()
neel_autoencoder.make_decoder_weights_and_grad_unit_norm()

# Check the decoder weights are the same with both models
assert torch.allclose(autoencoder.decoder.weight, neel_autoencoder.W_dec.T)

# Check the trivial case that the weights haven't just stayed the same as before the unit norm
assert not torch.allclose(autoencoder.decoder.weight, pre_unit_norm_weights)
assert not torch.allclose(neel_autoencoder.W_dec, pre_unit_norm_neel_weights)


def test_unit_norm_weights_grad() -> None:
"""Test that the decoder weights are unit normalized in the same way."""
torch.random.manual_seed(42)
n_input_features: int = 2
n_learned_features: int = 4
l1_coefficient: float = 0.01

autoencoder = SparseAutoencoder(
n_input_features=n_input_features,
n_learned_features=n_learned_features,
)
neel_autoencoder = NeelAutoencoder(
d_hidden=n_learned_features,
act_size=n_input_features,
l1_coeff=l1_coefficient,
)

# Set the same decoder weights
decoder_weights = torch.rand_like(autoencoder.decoder.weight)
autoencoder.decoder._weight.data = decoder_weights # noqa: SLF001 # type: ignore
neel_autoencoder.W_dec.data = decoder_weights.T
autoencoder.decoder._weight.grad = torch.zeros_like(autoencoder.decoder.weight) # noqa: SLF001 # type: ignore
neel_autoencoder.W_dec.grad = torch.zeros_like(neel_autoencoder.W_dec)

# Set the same tied bias weights
neel_autoencoder.b_dec.data = autoencoder.tied_bias.data
neel_autoencoder.b_enc.data = autoencoder.encoder.bias.data
neel_autoencoder.W_enc.data = autoencoder.encoder.weight.data.T

# Do a forward & backward pass so we have gradients
test_batch = torch.randn(4, n_input_features)
_learned, decoded = autoencoder.forward(test_batch)
    decoded.sum().backward()
neel_decoded = neel_autoencoder.forward(test_batch)[1]
    neel_decoded.sum().backward()

# Apply the unit norm
autoencoder.decoder.constrain_weights_unit_norm()
neel_autoencoder.make_decoder_weights_and_grad_unit_norm()

# Check the gradient weights are the same
assert autoencoder.decoder.weight.grad is not None
assert neel_autoencoder.W_dec.grad is not None
assert torch.allclose(autoencoder.decoder.weight.grad, neel_autoencoder.W_dec.grad.T, rtol=1e-4)
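
For intuition on the two unit-norm tests above, here is a minimal standalone sketch (not part of this commit; it assumes nothing beyond PyTorch) of the property that the gradient projection in make_decoder_weights_and_grad_unit_norm provides: once the component of the gradient parallel to each unit-norm decoder row is removed, a small gradient step changes the row norms only at second order.

import torch

# Unit-norm "decoder rows" and an arbitrary gradient (illustrative shapes only).
w = torch.nn.functional.normalize(torch.randn(4, 2), dim=-1)
grad = torch.randn(4, 2)

# Project out the component of the gradient parallel to each row, mirroring
# make_decoder_weights_and_grad_unit_norm.
parallel = (grad * w).sum(-1, keepdim=True) * w
projected = grad - parallel

# A small step along the projected gradient leaves each row norm at ~1;
# the deviation is O(step_size ** 2).
stepped = w - 1e-3 * projected
print(stepped.norm(dim=-1))
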
144 changes: 144 additions & 0 deletions sparse_autoencoder/loss/tests/test_neel_loss.py
@@ -0,0 +1,144 @@
"""Tests against Neel's Autoencoder Loss.
Compare module output against Neel's implementation at
https://github.com/neelnanda-io/1L-Sparse-Autoencoder/blob/main/utils.py.
"""
from typing import TypedDict

import pytest
import torch

from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss
from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss
from sparse_autoencoder.loss.reducer import LossReducer
from sparse_autoencoder.tensor_types import (
InputOutputActivationBatch,
ItemTensor,
LearnedActivationBatch,
)


def neel_loss(
source_activations: InputOutputActivationBatch,
learned_activations: LearnedActivationBatch,
decoded_activations: InputOutputActivationBatch,
l1_coefficient: float,
) -> tuple[ItemTensor, ItemTensor, ItemTensor]:
"""Neel's loss function."""
l2_loss = (decoded_activations.float() - source_activations.float()).pow(2).sum(-1).mean(0)
l1_loss = l1_coefficient * (learned_activations.float().abs().sum())
loss = l2_loss + l1_loss
return l1_loss, l2_loss, loss


def lib_loss(
source_activations: InputOutputActivationBatch,
learned_activations: LearnedActivationBatch,
decoded_activations: InputOutputActivationBatch,
l1_coefficient: float,
) -> tuple[ItemTensor, ItemTensor, ItemTensor]:
"""This library's loss function."""
l1_loss_fn = LearnedActivationsL1Loss(
l1_coefficient=float(l1_coefficient),
)
l2_loss_fn = L2ReconstructionLoss()

loss_fn = LossReducer(l1_loss_fn, l2_loss_fn)

l1_loss = l1_loss_fn.forward(source_activations, learned_activations, decoded_activations)
l2_loss = l2_loss_fn.forward(source_activations, learned_activations, decoded_activations)
total_loss = loss_fn.forward(source_activations, learned_activations, decoded_activations)
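    # Reduce to scalar totals so they can be compared against Neel's scalar losses.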
return l1_loss.sum(), l2_loss.sum(), total_loss.sum()


class MockActivations(TypedDict):
"""Mock activations."""

source_activations: InputOutputActivationBatch
learned_activations: LearnedActivationBatch
decoded_activations: InputOutputActivationBatch


@pytest.fixture()
def mock_activations() -> MockActivations:
"""Create mock activations.
Returns:
        Dictionary of source activations, learned activations, and decoded activations.
"""
source_activations = torch.rand(10, 20)
learned_activations = torch.rand(10, 50)
decoded_activations = torch.rand(10, 20)
return {
"source_activations": source_activations,
"learned_activations": learned_activations,
"decoded_activations": decoded_activations,
}


def test_l1_loss_the_same(mock_activations: MockActivations) -> None:
"""Test that the L1 loss is the same."""
l1_coefficient: float = 0.01

neel_l1_loss = neel_loss(
source_activations=mock_activations["source_activations"],
learned_activations=mock_activations["learned_activations"],
decoded_activations=mock_activations["decoded_activations"],
l1_coefficient=l1_coefficient,
)[0]

lib_l1_loss = lib_loss(
source_activations=mock_activations["source_activations"],
learned_activations=mock_activations["learned_activations"],
decoded_activations=mock_activations["decoded_activations"],
l1_coefficient=l1_coefficient,
)[0].sum()

assert torch.allclose(neel_l1_loss, lib_l1_loss)


def test_l2_loss_the_same(mock_activations: MockActivations) -> None:
"""Test that the L2 loss is the same."""
l1_coefficient: float = 0.01

neel_l2_loss = neel_loss(
source_activations=mock_activations["source_activations"],
learned_activations=mock_activations["learned_activations"],
decoded_activations=mock_activations["decoded_activations"],
l1_coefficient=l1_coefficient,
)[1]

lib_l2_loss = lib_loss(
source_activations=mock_activations["source_activations"],
learned_activations=mock_activations["learned_activations"],
decoded_activations=mock_activations["decoded_activations"],
l1_coefficient=l1_coefficient,
)[1].sum()

    # Adjust for the fact that Neel's L2 loss sums over the feature dimension and then averages
    # over the batch, whereas the L1 loss (and the reduction used for comparison here) sums over
    # both the feature and batch dimensions.
neel_l2_loss_fixed = neel_l2_loss * len(mock_activations["source_activations"])

assert torch.allclose(neel_l2_loss_fixed, lib_l2_loss)


@pytest.mark.skip("We believe Neel's L2 approach is different to the original paper.")
def test_total_loss_the_same(mock_activations: MockActivations) -> None:
"""Test that the total loss is the same."""
l1_coefficient: float = 0.01

neel_total_loss = neel_loss(
source_activations=mock_activations["source_activations"],
learned_activations=mock_activations["learned_activations"],
decoded_activations=mock_activations["decoded_activations"],
l1_coefficient=l1_coefficient,
)[2].sum()

lib_total_loss = lib_loss(
source_activations=mock_activations["source_activations"],
learned_activations=mock_activations["learned_activations"],
decoded_activations=mock_activations["decoded_activations"],
l1_coefficient=l1_coefficient,
)[2].sum()

assert torch.allclose(neel_total_loss, lib_total_loss)
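
As a footnote to the batch-size adjustment in test_l2_loss_the_same, the following standalone snippet (not part of this commit; shapes are arbitrary) shows the relationship being relied on: averaging the per-example squared errors over the batch, as Neel's loss does, differs from summing them, as the comparison in these tests does, by exactly the batch size.

import torch

batch_size, n_features = 10, 20
source = torch.rand(batch_size, n_features)
decoded = torch.rand(batch_size, n_features)

# Squared reconstruction error per example, summed over the feature dimension.
per_example = (decoded - source).pow(2).sum(-1)

# Mean over the batch (Neel's convention) times the batch size recovers the
# batch-summed form used for the comparison above.
assert torch.allclose(per_example.mean(0) * batch_size, per_example.sum(0))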
