From a495994037ab5c928789750776aa29c051fd4a08 Mon Sep 17 00:00:00 2001 From: Alan <41682961+alan-cooney@users.noreply.github.com> Date: Sun, 10 Dec 2023 15:53:51 -0300 Subject: [PATCH] Improve pipeline test coverage (#148) --- .../source_data/mock_dataset.py | 1 + sparse_autoencoder/train/pipeline.py | 16 +++- .../train/tests/test_pipeline.py | 93 ++++++++++++++++++- 3 files changed, 104 insertions(+), 6 deletions(-) diff --git a/sparse_autoencoder/source_data/mock_dataset.py b/sparse_autoencoder/source_data/mock_dataset.py index de974bb8..1fedb5e3 100644 --- a/sparse_autoencoder/source_data/mock_dataset.py +++ b/sparse_autoencoder/source_data/mock_dataset.py @@ -174,3 +174,4 @@ def __init__( dataset_split: Dataset split (e.g. `train`). """ self.dataset = ConsecutiveIntHuggingFaceDataset(context_size=context_size) # type: ignore + self.context_size = context_size diff --git a/sparse_autoencoder/train/pipeline.py b/sparse_autoencoder/train/pipeline.py index fcd15140..e9a63f5d 100644 --- a/sparse_autoencoder/train/pipeline.py +++ b/sparse_autoencoder/train/pipeline.py @@ -335,7 +335,8 @@ def validate_sae(self, validation_number_activations: int) -> None: ) for metric in self.metrics.validation_metrics: calculated = metric.calculate(validation_data) - wandb.log(data=calculated, commit=False) + if wandb.run is not None: + wandb.log(data=calculated, commit=False) @final def save_checkpoint(self) -> None: @@ -411,10 +412,15 @@ def run_pipeline( ) if parameter_updates is not None: - wandb.log( - {"resample/dead_neurons": len(parameter_updates.dead_neuron_indices)}, - commit=False, - ) + if wandb.run is not None: + wandb.log( + { + "resample/dead_neurons": len( + parameter_updates.dead_neuron_indices + ) + }, + commit=False, + ) # Update the parameters self.update_parameters(parameter_updates) diff --git a/sparse_autoencoder/train/tests/test_pipeline.py b/sparse_autoencoder/train/tests/test_pipeline.py index 59a9d9fd..80f616e1 100644 --- 
a/sparse_autoencoder/train/tests/test_pipeline.py +++ b/sparse_autoencoder/train/tests/test_pipeline.py @@ -1,4 +1,6 @@ """Test the pipeline module.""" +from typing import Any +from unittest.mock import MagicMock import pytest import torch @@ -15,7 +17,12 @@ from sparse_autoencoder.activation_resampler.abstract_activation_resampler import ( ParameterUpdateResults, ) +from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore +from sparse_autoencoder.metrics.validate.abstract_validate_metric import ( + AbstractValidationMetric, + ValidationMetricData, +) from sparse_autoencoder.source_data.mock_dataset import MockDataset @@ -39,9 +46,10 @@ def pipeline_fixture() -> Pipeline: named_parameters=autoencoder.named_parameters(), ) source_data = MockDataset(context_size=100) + activation_resampler = ActivationResampler(n_learned_features=autoencoder.n_learned_features) return Pipeline( - activation_resampler=None, + activation_resampler=activation_resampler, autoencoder=autoencoder, cache_name="blocks.0.hook_mlp_out", layer=0, @@ -239,3 +247,86 @@ def test_optimizer_state_changed(self, pipeline_fixture: Pipeline) -> None: dtype=torch.float, ), ), "Optimizer non-dead neuron state should not have changed after training." 
+
+
+class TestValidateSAE:
+    """Test the validate_sae method."""
+
+    def test_reconstruction_loss_more_than_base(self, pipeline_fixture: Pipeline) -> None:
+        """Test that the reconstruction and zero-ablation losses both exceed the base loss."""
+
+        # Spy metric: keeps the data handed to `calculate` so it can be inspected afterwards
+        class StoreValidationMetric(AbstractValidationMetric):
+            """Dummy metric that stores the validation data it is given."""
+
+            data: ValidationMetricData | None = None  # Stays None until `calculate` runs
+
+            def calculate(self, data: ValidationMetricData) -> dict[str, Any]:
+                """Store the data and return an empty metrics dict."""
+                self.data = data
+                return {}
+
+        dummy_metric = StoreValidationMetric()
+        pipeline_fixture.metrics.validation_metrics.append(dummy_metric)
+
+        # Generate activations, then run the validation loop
+        store_size: int = 1000
+        pipeline_fixture.generate_activations(store_size)
+        pipeline_fixture.validate_sae(store_size)
+
+        # Check the losses (the first assert fails cleanly if the metric was never called)
+        assert (
+            dummy_metric.data is not None
+        ), "Dummy metric should have stored the data from the validation loop."
+        assert (
+            dummy_metric.data.source_model_loss_with_reconstruction
+            > dummy_metric.data.source_model_loss
+        ), "Reconstruction loss should be more than base loss."
+
+        assert (
+            dummy_metric.data.source_model_loss_with_zero_ablation
+            > dummy_metric.data.source_model_loss
+        ), "Zero ablation loss should be more than base loss."
+
+
+class TestRunPipeline:
+    """Test the run_pipeline method."""
+
+    def test_run_pipeline_calls_all_methods(self, pipeline_fixture: Pipeline) -> None:
+        """Test that run_pipeline calls validate, checkpoint and the resampler as expected."""
+        assert pipeline_fixture.activation_resampler is not None, "Fixture needs a resampler."
+        pipeline_fixture.validate_sae = MagicMock(spec=Pipeline.validate_sae)  # type: ignore
+        pipeline_fixture.save_checkpoint = MagicMock(spec=Pipeline.save_checkpoint)  # type: ignore
+        pipeline_fixture.activation_resampler.step_resampler = MagicMock(  # type: ignore
+            spec=ActivationResampler.step_resampler, return_value=None
+        )
+
+        store_size = 1000
+        context_size = pipeline_fixture.source_dataset.context_size
+        train_batch_size = store_size // context_size
+
+        total_loops = 5  # Also scales max_activations below, so the two cannot drift apart
+        validate_expected_calls = 2
+        checkpoint_expected_calls = 5
+
+        pipeline_fixture.run_pipeline(
+            train_batch_size=train_batch_size,
+            max_store_size=store_size,
+            max_activations=store_size * total_loops,
+            validation_number_activations=store_size,
+            validate_frequency=store_size * (total_loops // validate_expected_calls),
+            checkpoint_frequency=store_size * (total_loops // checkpoint_expected_calls),
+        )
+
+        # Check how many times each mocked method was called
+        assert (
+            pipeline_fixture.validate_sae.call_count == validate_expected_calls
+        ), f"Validate should have been called {validate_expected_calls} times."
+
+        assert (
+            pipeline_fixture.save_checkpoint.call_count == checkpoint_expected_calls
+        ), f"Checkpoint should have been called {checkpoint_expected_calls} times."
+
+        assert (
+            pipeline_fixture.activation_resampler.step_resampler.call_count == total_loops
+        ), f"Resampler should have been called {total_loops} times."