Skip to content

Commit

Permalink
Add categorisation to all wandb metrics (#136)
Browse files (browse the repository at this point in the history)
  • Loading branch information
alan-cooney authored Dec 4, 2023
1 parent e681b29 commit 15378ab
Show file tree
Hide file tree
Showing 13 changed files with 25 additions and 21 deletions.
2 changes: 1 addition & 1 deletion sparse_autoencoder/loss/abstract_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def batch_scalar_loss_with_log(
)

# Add in the current loss module's metric
log_name = self.log_name()
log_name = "train/loss/" + self.log_name()
metrics[log_name] = current_module_loss.detach().cpu().item()

return current_module_loss, metrics
Expand Down
2 changes: 1 addition & 1 deletion sparse_autoencoder/loss/decoded_activations_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class L2ReconstructionLoss(AbstractLoss):
>>> unused_activations = torch.zeros_like(input_activations)
>>> # Outputs both loss and metrics to log
>>> loss(input_activations, unused_activations, output_activations)
(tensor(5.5000), {'l2_reconstruction_loss': 5.5})
(tensor(5.5000), {'train/loss/l2_reconstruction_loss': 5.5})
"""

_reduction: LossReductionType
Expand Down
4 changes: 2 additions & 2 deletions sparse_autoencoder/loss/learned_activations_l1.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def batch_scalar_loss_with_log(
batch_scalar_loss_penalty = absolute_loss_penalty.sum().squeeze()

metrics = {
"learned_activations_l1_loss": batch_scalar_loss.item(),
self.log_name(): batch_scalar_loss_penalty.item(),
"train/loss/" + "learned_activations_l1_loss": batch_scalar_loss.item(),
"train/loss/" + self.log_name(): batch_scalar_loss_penalty.item(),
}

return batch_scalar_loss_penalty, metrics
Expand Down
4 changes: 2 additions & 2 deletions sparse_autoencoder/loss/tests/test_abstract_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def test_batch_scalar_loss_with_log(dummy_loss: DummyLoss) -> None:
source_activations, learned_activations, decoded_activations
)
expected = 2.0 # Mean of [1.0, 2.0, 3.0]
assert log["dummy"] == expected
assert log["train/loss/dummy"] == expected


def test_call_method(dummy_loss: DummyLoss) -> None:
"""Test the call method."""
source_activations = learned_activations = decoded_activations = torch.ones((1, 3))
_loss, log = dummy_loss(source_activations, learned_activations, decoded_activations)
expected = 2.0 # Mean of [1.0, 2.0, 3.0]
assert log["dummy"] == expected
assert log["train/loss/dummy"] == expected
2 changes: 1 addition & 1 deletion sparse_autoencoder/metrics/train/capacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,5 @@ def calculate(self, data: TrainMetricData) -> dict[str, Any]:
train_batch_capacities_histogram = self.wandb_capacities_histogram(train_batch_capacities)

return {
"train_batch_capacities_histogram": train_batch_capacities_histogram,
"train/batch_capacities_histogram": train_batch_capacities_histogram,
}
2 changes: 1 addition & 1 deletion sparse_autoencoder/metrics/train/feature_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,5 @@ def calculate(self, data: TrainMetricData) -> dict[str, Any]:
)

return {
"train_batch_feature_density_histogram": train_batch_feature_density_histogram,
"train/batch_feature_density_histogram": train_batch_feature_density_histogram,
}
2 changes: 1 addition & 1 deletion sparse_autoencoder/metrics/train/l0_norm_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ def calculate(self, data: TrainMetricData) -> dict[str, float]:
batch_size = data.learned_activations.size(0)
n_non_zero_activations = torch.count_nonzero(data.learned_activations)
batch_average = n_non_zero_activations / batch_size
return {"learned_activations_l0_norm": batch_average.item()}
return {"train/learned_activations_l0_norm": batch_average.item()}
2 changes: 1 addition & 1 deletion sparse_autoencoder/metrics/train/tests/test_capacities.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ def test_calculate_returns_histogram() -> None:
decoded_activations=activations,
)
)
assert "train_batch_capacities_histogram" in res
assert "train/batch_capacities_histogram" in res
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def test_calculate_aggregates() -> None:
)

# Check both metrics are in the result
assert "train_batch_feature_density_histogram" in res
assert "train/batch_feature_density_histogram" in res
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ def test_l0_norm_metric() -> None:
)
log = l0_norm_metric.calculate(data)
expected = 3 / 2
assert log["learned_activations_l0_norm"] == expected
assert log["train/learned_activations_l0_norm"] == expected
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def calculate(self, data: ValidationMetricData) -> dict[str, Any]:
... )
>>> metric = ModelReconstructionScore()
>>> result = metric.calculate(data)
>>> round(result['model_reconstruction_score'], 3)
>>> round(result['validate/model_reconstruction_score'], 3)
0.667
Args:
Expand Down
Expand Up @@ -78,8 +78,8 @@ def calculate(self, data: ValidationMetricData) -> dict[str, Any]:
)

return {
"validation_baseline_loss": validation_baseline_loss,
"validation_loss_with_reconstruction": validation_loss_with_reconstruction,
"validation_loss_with_zero_ablation": validation_loss_with_zero_ablation,
"model_reconstruction_score": model_reconstruction_score,
"validate/baseline_loss": validation_baseline_loss,
"validate/loss_with_reconstruction": validation_loss_with_reconstruction,
"validate/loss_with_zero_ablation": validation_loss_with_zero_ablation,
"validate/model_reconstruction_score": model_reconstruction_score,
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def test_model_reconstruction_score_various_data(
"""
metric = ModelReconstructionScore()
result = metric.calculate(data)
assert round(result["model_reconstruction_score"], 2) == expected_score
assert round(result["validate/model_reconstruction_score"], 2) == expected_score
10 changes: 7 additions & 3 deletions sparse_autoencoder/train/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial
from pathlib import Path
from typing import final
from urllib.parse import quote_plus

from jaxtyping import Int
import torch
Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__( # noqa: PLR0913
optimizer: AbstractOptimizerWithReset,
source_dataset: SourceDataset,
source_model: HookedTransformer,
run_name: str = "sparse_autoencoder",
checkpoint_directory: Path | None = None,
log_frequency: int = 100,
metrics: MetricsContainer = default_metrics,
Expand All @@ -104,6 +106,7 @@ def __init__( # noqa: PLR0913
optimizer: Optimizer to use.
source_dataset: Source dataset to get data from.
source_model: Source model to get activations from.
run_name: Name of the run for saving checkpoints.
checkpoint_directory: Directory to save checkpoints to.
log_frequency: Frequency at which to log metrics (in steps)
metrics: Metrics to use.
Expand All @@ -118,6 +121,7 @@ def __init__( # noqa: PLR0913
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.run_name = run_name
self.source_data_batch_size = source_data_batch_size
self.source_dataset = source_dataset
self.source_model = source_model
Expand Down
Expand Up @@ -331,8 +335,10 @@ def validate_sae(self, validation_number_activations: int) -> None:
def save_checkpoint(self) -> None:
"""Save the model as a checkpoint."""
if self.checkpoint_directory:
run_name_file_system_safe = quote_plus(self.run_name)
file_path: Path = (
self.checkpoint_directory / f"sae_state_dict-{self.total_training_steps}.pt"
self.checkpoint_directory
/ f"{run_name_file_system_safe}-{self.total_training_steps}.pt"
)
torch.save(self.autoencoder.state_dict(), file_path)

Expand Down
Expand Up @@ -383,8 +389,6 @@ def run_pipeline(
last_validated += num_activation_vectors_in_store
last_checkpoint += num_activation_vectors_in_store
total_activations += num_activation_vectors_in_store
if wandb.run is not None:
wandb.log({"activations_generated": total_activations}, commit=False)

# Train
progress_bar.set_postfix({"stage": "train"})
Expand Down

0 comments on commit 15378ab

Please sign in to comment.