
Commit

Add a model reconstruction validation metric (#112)
alan-cooney authored Nov 26, 2023
1 parent b2c821f commit a8fe5d0
Showing 8 changed files with 269 additions and 64 deletions.
107 changes: 53 additions & 54 deletions docs/content/demo.ipynb
@@ -107,16 +107,27 @@
" # and we have found that 4x is a good starting point.\n",
" \"expansion_factor\": 4,\n",
" # L1 coefficient is the coefficient of the L1 regularization term (used to encourage sparsity).\n",
" \"l1_coefficient\": 3e-4,\n",
" \"l1_coefficient\": 1e-3,\n",
" # Adam parameters (set to the default ones here)\n",
" \"lr\": 1e-4,\n",
" \"lr\": 3e-4,\n",
" \"adam_beta_1\": 0.9,\n",
" \"adam_beta_2\": 0.999,\n",
" \"adam_epsilon\": 1e-8,\n",
" \"adam_weight_decay\": 0.0,\n",
" # Batch sizes\n",
" \"train_batch_size\": 4096,\n",
" \"context_size\": 128,\n",
" # Source model hook point\n",
" \"source_model_name\": \"gelu-2l\",\n",
" \"source_model_dtype\": \"float32\",\n",
" \"source_model_hook_point\": \"blocks.0.hook_mlp_out\",\n",
" \"source_model_hook_point_layer\": 0,\n",
" # Train pipeline parameters\n",
" \"max_store_size\": 384 * 4096 * 2,\n",
" \"max_activations\": 2_000_000_000,\n",
" \"resample_frequency\": 122_880_000,\n",
" \"checkpoint_frequency\": 100_000_000,\n",
" \"validation_frequency\": 384 * 4096 * 2 * 100, # Every 100 generations\n",
"}"
]
},
@@ -141,7 +152,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -157,22 +168,22 @@
"'Source: gelu-2l, Hook: blocks.0.hook_mlp_out, Features: 512'"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Source model setup with TransformerLens\n",
"src_model_name = \"gelu-2l\"\n",
"src_model = HookedTransformer.from_pretrained(src_model_name, dtype=\"float32\")\n",
"src_model = HookedTransformer.from_pretrained(\n",
" str(hyperparameters[\"source_model_name\"]), dtype=str(hyperparameters[\"source_model_dtype\"])\n",
")\n",
"\n",
"# Details about the activations we'll train the sparse autoencoder on\n",
"src_model_activation_hook_point = \"blocks.0.hook_mlp_out\"\n",
"src_model_activation_layer = 0\n",
"autoencoder_input_dim: int = src_model.cfg.d_model # type: ignore (TransformerLens typing is currently broken)\n",
"\n",
"f\"Source: {src_model_name}, Hook: {src_model_activation_hook_point}, \\\n",
"f\"Source: {hyperparameters['source_model_name']}, \\\n",
" Hook: {hyperparameters['source_model_hook_point']}, \\\n",
" Features: {autoencoder_input_dim}\""
]
},
@@ -199,7 +210,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -216,7 +227,7 @@
")"
]
},
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -244,19 +255,19 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LossReducer(\n",
" (0): LearnedActivationsL1Loss(l1_coefficient=0.0003)\n",
" (0): LearnedActivationsL1Loss(l1_coefficient=0.001)\n",
" (1): L2ReconstructionLoss()\n",
")"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -265,7 +276,7 @@
"# We use a loss reducer, which simply adds up the losses from the underlying loss functions.\n",
"loss = LossReducer(\n",
" LearnedActivationsL1Loss(\n",
" l1_coefficient=hyperparameters[\"l1_coefficient\"],\n",
" l1_coefficient=float(hyperparameters[\"l1_coefficient\"]),\n",
" ),\n",
" L2ReconstructionLoss(),\n",
")\n",
@@ -274,7 +285,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -289,13 +300,13 @@
" eps: 1e-08\n",
" foreach: None\n",
" fused: None\n",
" lr: 0.0001\n",
" lr: 0.0003\n",
" maximize: False\n",
" weight_decay: 0.0\n",
")"
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -304,10 +315,10 @@
"optimizer = AdamWithReset(\n",
" params=autoencoder.parameters(),\n",
" named_parameters=autoencoder.named_parameters(),\n",
" lr=hyperparameters[\"lr\"],\n",
" betas=(hyperparameters[\"adam_beta_1\"], hyperparameters[\"adam_beta_2\"]),\n",
" eps=hyperparameters[\"adam_epsilon\"],\n",
" weight_decay=hyperparameters[\"adam_weight_decay\"],\n",
" lr=float(hyperparameters[\"lr\"]),\n",
" betas=(float(hyperparameters[\"adam_beta_1\"]), float(hyperparameters[\"adam_beta_2\"])),\n",
" eps=float(hyperparameters[\"adam_epsilon\"]),\n",
" weight_decay=float(hyperparameters[\"adam_weight_decay\"]),\n",
")\n",
"optimizer"
]
@@ -321,7 +332,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -345,13 +356,13 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4bdf3ebe364243bd8f881933e56c997d",
"model_id": "75e636ebb9e04b279c7216c74496538d",
"version_major": 2,
"version_minor": 0
},
@@ -390,7 +401,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -400,7 +411,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -426,7 +437,7 @@
{
"data": {
"text/html": [
"Run data is saved locally in <code>.cache/wandb/run-20231126_122954-xsruek7y</code>"
"Run data is saved locally in <code>.cache/wandb/run-20231126_184500-2fnpg8zi</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -438,7 +449,7 @@
{
"data": {
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y' target=\"_blank\">vivid-totem-95</a></strong> to <a href='https://wandb.ai/alan-cooney/sparse-autoencoder' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
"Syncing run <strong><a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi' target=\"_blank\">prime-star-105</a></strong> to <a href='https://wandb.ai/alan-cooney/sparse-autoencoder' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -462,7 +473,7 @@
{
"data": {
"text/html": [
" View run at <a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y' target=\"_blank\">https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y</a>"
" View run at <a href='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi' target=\"_blank\">https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
@@ -474,13 +485,13 @@
{
"data": {
"text/html": [
"<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/xsruek7y?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
"<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/alan-cooney/sparse-autoencoder/runs/2fnpg8zi?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
],
"text/plain": [
"<wandb.sdk.wandb_run.Run at 0x2ff1cbcd0>"
"<wandb.sdk.wandb_run.Run at 0x3154cec10>"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@@ -496,13 +507,13 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e1e6fa019f524f3da19708a4eda9b349",
"model_id": "1322f5e5dd5c4507a6eca9aa1f010882",
"version_major": 2,
"version_minor": 0
},
@@ -526,9 +537,9 @@
"pipeline = Pipeline(\n",
" activation_resampler=activation_resampler,\n",
" autoencoder=autoencoder,\n",
" cache_name=src_model_activation_hook_point,\n",
" cache_name=str(hyperparameters[\"source_model_hook_point\"]),\n",
" checkpoint_directory=checkpoint_path,\n",
" layer=src_model_activation_layer,\n",
" layer=int(hyperparameters[\"source_model_hook_point_layer\"]),\n",
" loss=loss,\n",
" optimizer=optimizer,\n",
" source_data_batch_size=6,\n",
Expand All @@ -538,11 +549,11 @@
"\n",
"pipeline.run_pipeline(\n",
" train_batch_size=int(hyperparameters[\"train_batch_size\"]),\n",
" max_store_size=384 * 4096 * 2,\n",
" # Sizes for demo purposes (you probably want to scale these by 10x)\n",
" max_activations=2_000_000_000,\n",
" resample_frequency=122_880_000,\n",
" checkpoint_frequency=100_000_000,\n",
" max_store_size=int(hyperparameters[\"max_store_size\"]),\n",
" max_activations=int(hyperparameters[\"max_activations\"]),\n",
" resample_frequency=int(hyperparameters[\"resample_frequency\"]),\n",
" checkpoint_frequency=int(hyperparameters[\"checkpoint_frequency\"]),\n",
" validate_frequency=int(hyperparameters[\"validation_frequency\"]),\n",
")"
]
},
@@ -554,18 +565,6 @@
"source": [
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Advice\n",
"\n",
"-- Unfinished --\n",
"\n",
"- Check recovery loss is low while sparsity is low as well (<20 L1) usually.\n",
"- Can't be sure features are useful until you dig into them more. "
]
}
],
"metadata": {
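The notebook now drives the train pipeline entirely from the hyperparameters dict, including the new validation_frequency. That value is expressed in activations seen rather than steps; a quick sketch of the arithmetic behind the demo's setting (variable names here are illustrative, not part of the library):

    # Each "generation" fills the activation store once.
    max_store_size = 384 * 4096 * 2            # 3,145,728 activations per store fill
    generations_between_validations = 100

    # validate_frequency is measured in activations processed, so validating
    # every 100 generations works out to:
    validation_frequency = max_store_size * generations_between_validations
    print(f"{validation_frequency:,}")         # 314,572,800
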
2 changes: 2 additions & 0 deletions sparse_autoencoder/metrics/metrics_container.py
@@ -8,6 +8,7 @@
from sparse_autoencoder.metrics.train.capacity import CapacityMetric
from sparse_autoencoder.metrics.train.feature_density import TrainBatchFeatureDensityMetric
from sparse_autoencoder.metrics.validate.abstract_validate_metric import AbstractValidationMetric
from sparse_autoencoder.metrics.validate.model_reconstruction_score import ModelReconstructionScore


@dataclass
@@ -33,5 +34,6 @@ class MetricsContainer:
default_metrics = MetricsContainer(
train_metrics=[TrainBatchFeatureDensityMetric(), CapacityMetric()],
resample_metrics=[NeuronActivityMetric()],
validation_metrics=[ModelReconstructionScore()],
)
"""Default metrics container."""
sparse_autoencoder/metrics/validate/abstract_validate_metric.py
@@ -3,14 +3,18 @@
from dataclasses import dataclass
from typing import Any

from sparse_autoencoder.tensor_types import ValidationStatistics


@dataclass
class ValidationMetricData:
"""Validation metric data."""

source_model_loss: float
source_model_loss: ValidationStatistics

source_model_loss_with_reconstruction: ValidationStatistics

autoencoder_loss: float
source_model_loss_with_zero_ablation: ValidationStatistics


class AbstractValidationMetric(ABC):
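The ModelReconstructionScore implementation itself is not shown in this view, but the three statistics now collected in ValidationMetricData are exactly what a reconstruction score needs: the source model's loss, its loss when the hooked activations are replaced by the autoencoder's reconstruction, and its loss when those activations are zero-ablated. A minimal sketch of the usual formulation, assuming each statistic is a tensor of per-batch losses (the function name and signature here are illustrative, not the library's API):

    import torch


    def model_reconstruction_score(
        source_model_loss: torch.Tensor,
        source_model_loss_with_reconstruction: torch.Tensor,
        source_model_loss_with_zero_ablation: torch.Tensor,
    ) -> float:
        """Fraction of the zero-ablation loss gap recovered by the reconstruction.

        1.0 means substituting the autoencoder's reconstruction is as good as the
        original activations; 0.0 means it is no better than zeroing them out.
        """
        baseline = source_model_loss.mean()
        with_reconstruction = source_model_loss_with_reconstruction.mean()
        with_zero_ablation = source_model_loss_with_zero_ablation.mean()

        return float(
            (with_zero_ablation - with_reconstruction) / (with_zero_ablation - baseline)
        )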