Improve pipeline test coverage (#148)
alan-cooney authored Dec 10, 2023
1 parent 4e488b0 commit a495994
Showing 3 changed files with 104 additions and 6 deletions.
1 change: 1 addition & 0 deletions sparse_autoencoder/source_data/mock_dataset.py
@@ -174,3 +174,4 @@ def __init__(
             dataset_split: Dataset split (e.g. `train`).
         """
         self.dataset = ConsecutiveIntHuggingFaceDataset(context_size=context_size)  # type: ignore
+        self.context_size = context_size
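
Note: exposing `context_size` on the mock dataset is what lets the new `TestRunPipeline` test below read `pipeline_fixture.source_dataset.context_size` instead of hard-coding it. A minimal sketch of the arithmetic this attribute enables, assuming each prompt of `context_size` tokens yields `context_size` activations:

    # Sketch only; mirrors the batch-size arithmetic in test_pipeline.py below.
    store_size = 1_000
    context_size = 100  # MockDataset(context_size=100) in the test fixture
    train_batch_size = store_size // context_size  # 10 prompts fill the store once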
16 changes: 11 additions & 5 deletions sparse_autoencoder/train/pipeline.py
@@ -335,7 +335,8 @@ def validate_sae(self, validation_number_activations: int) -> None:
         )
         for metric in self.metrics.validation_metrics:
             calculated = metric.calculate(validation_data)
-            wandb.log(data=calculated, commit=False)
+            if wandb.run is not None:
+                wandb.log(data=calculated, commit=False)

     @final
     def save_checkpoint(self) -> None:
@@ -411,10 +412,15 @@ def run_pipeline(
             )

             if parameter_updates is not None:
-                wandb.log(
-                    {"resample/dead_neurons": len(parameter_updates.dead_neuron_indices)},
-                    commit=False,
-                )
+                if wandb.run is not None:
+                    wandb.log(
+                        {
+                            "resample/dead_neurons": len(
+                                parameter_updates.dead_neuron_indices
+                            )
+                        },
+                        commit=False,
+                    )

                 # Update the parameters
                 self.update_parameters(parameter_updates)
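Both hunks apply the same guard: `wandb.run` is `None` until `wandb.init()` has been called, and `wandb.log()` raises an error when no run is active, so checking the run first lets the pipeline execute in unit tests that never initialise Weights & Biases. A minimal sketch of the pattern as a standalone helper (`maybe_log` is hypothetical, not part of this commit):

    import wandb

    def maybe_log(data: dict, *, commit: bool = True) -> None:
        # Skip logging when no wandb run is active (e.g. during unit tests).
        if wandb.run is not None:
            wandb.log(data=data, commit=commit)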
93 changes: 92 additions & 1 deletion sparse_autoencoder/train/tests/test_pipeline.py
@@ -1,4 +1,6 @@
"""Test the pipeline module."""
from typing import Any
from unittest.mock import MagicMock

import pytest
import torch
@@ -15,7 +17,12 @@
 from sparse_autoencoder.activation_resampler.abstract_activation_resampler import (
     ParameterUpdateResults,
 )
+from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler
 from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
+from sparse_autoencoder.metrics.validate.abstract_validate_metric import (
+    AbstractValidationMetric,
+    ValidationMetricData,
+)
 from sparse_autoencoder.source_data.mock_dataset import MockDataset


@@ -39,9 +46,10 @@ def pipeline_fixture() -> Pipeline:
         named_parameters=autoencoder.named_parameters(),
     )
     source_data = MockDataset(context_size=100)
+    activation_resampler = ActivationResampler(n_learned_features=autoencoder.n_learned_features)

     return Pipeline(
-        activation_resampler=None,
+        activation_resampler=activation_resampler,
         autoencoder=autoencoder,
         cache_name="blocks.0.hook_mlp_out",
         layer=0,
@@ -239,3 +247,86 @@ def test_optimizer_state_changed(self, pipeline_fixture: Pipeline) -> None:
                 dtype=torch.float,
             ),
         ), "Optimizer non-dead neuron state should not have changed after training."


+class TestValidateSAE:
+    """Test the validate_sae method."""
+
+    def test_reconstruction_loss_more_than_base(self, pipeline_fixture: Pipeline) -> None:
+        """Test that the reconstruction loss is more than the base loss."""
+
+        # Create a dummy metric, so we can retrieve the stored data afterwards
+        class StoreValidationMetric(AbstractValidationMetric):
+            """Dummy metric to store the data."""
+
+            data: ValidationMetricData | None
+
+            def calculate(self, data: ValidationMetricData) -> dict[str, Any]:
+                """Store the data."""
+                self.data = data
+                return {}
+
+        dummy_metric = StoreValidationMetric()
+        pipeline_fixture.metrics.validation_metrics.append(dummy_metric)
+
+        # Run the validation loop
+        store_size: int = 1000
+        pipeline_fixture.generate_activations(store_size)
+        pipeline_fixture.validate_sae(store_size)
+
+        # Check the loss
+        assert (
+            dummy_metric.data is not None
+        ), "Dummy metric should have stored the data from the validation loop."
+        assert (
+            dummy_metric.data.source_model_loss_with_reconstruction
+            > dummy_metric.data.source_model_loss
+        ), "Reconstruction loss should be more than base loss."
+
+        assert (
+            dummy_metric.data.source_model_loss_with_zero_ablation
+            > dummy_metric.data.source_model_loss
+        ), "Zero ablation loss should be more than base loss."
+
+
+class TestRunPipeline:
+    """Test the run_pipeline method."""
+
+    def test_run_pipeline_calls_all_methods(self, pipeline_fixture: Pipeline) -> None:
+        """Test that the run_pipeline method calls all the other methods."""
+        pipeline_fixture.validate_sae = MagicMock(spec=Pipeline.validate_sae)  # type: ignore
+        pipeline_fixture.save_checkpoint = MagicMock(spec=Pipeline.save_checkpoint)  # type: ignore
+        pipeline_fixture.activation_resampler.step_resampler = MagicMock(  # type: ignore
+            spec=ActivationResampler.step_resampler, return_value=None
+        )
+
+        store_size = 1000
+        context_size = pipeline_fixture.source_dataset.context_size
+        train_batch_size = store_size // context_size
+
+        total_loops = 5
+        validate_expected_calls = 2
+        checkpoint_expected_calls = 5
+
+        pipeline_fixture.run_pipeline(
+            train_batch_size=train_batch_size,
+            max_store_size=store_size,
+            max_activations=store_size * 5,
+            validation_number_activations=store_size,
+            validate_frequency=store_size * (total_loops // validate_expected_calls),
+            checkpoint_frequency=store_size * (total_loops // checkpoint_expected_calls),
+        )
+
+        # Check the number of calls
+        assert (
+            pipeline_fixture.validate_sae.call_count == validate_expected_calls
+        ), f"Validate should have been called {validate_expected_calls} times."
+
+        assert (
+            pipeline_fixture.save_checkpoint.call_count == checkpoint_expected_calls
+        ), f"Checkpoint should have been called {checkpoint_expected_calls} times."
+
+        assert pipeline_fixture.activation_resampler is not None
+        assert (
+            pipeline_fixture.activation_resampler.step_resampler.call_count == total_loops
+        ), f"Resampler should have been called {total_loops} times."

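Why those expected call counts: `max_activations = store_size * 5` with `max_store_size = store_size` gives five store-filling loops, and the test sets `validate_frequency = 2 * store_size` and `checkpoint_frequency = store_size`. A worked schedule, assuming the pipeline compares each frequency against the running activation total once per loop:

    # store_size = 1_000 -> validate_frequency = 2_000, checkpoint_frequency = 1_000
    # loop:                   1      2      3      4      5
    # activations so far:     1000   2000   3000   4000   5000
    # validate   (per 2000):  -      yes    -      yes    -     -> 2 calls
    # checkpoint (per 1000):  yes    yes    yes    yes    yes   -> 5 calls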