Commit
Fix demo notebook (#50)
alan-cooney authored Nov 7, 2023
1 parent a5ad6c5 commit 29ff15a
Showing 7 changed files with 157 additions and 46 deletions.
160 changes: 125 additions & 35 deletions demo.ipynb
@@ -31,11 +31,11 @@
"metadata": {},
"outputs": [],
"source": [
"from sparse_autoencoder import TensorActivationStore, SparseAutoencoder\n",
"from sparse_autoencoder import TensorActivationStore, SparseAutoencoder, pipeline\n",
"from sparse_autoencoder.source_data.pile_uncopyrighted import PileUncopyrightedDataset\n",
"from transformer_lens import HookedTransformer\n",
"from transformer_lens.utils import get_device\n",
"from transformers import GPT2TokenizerFast\n",
"from transformers import PreTrainedTokenizerBase\n",
"import torch\n",
"import wandb"
]
@@ -53,39 +53,43 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Source Dataset"
"### Source Model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded pretrained model solu-1l into HookedTransformer\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2f8dd302359141b09ba44ba4c1d519d9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/30 [00:00<?, ?it/s]"
"2048"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "display_data"
"output_type": "execute_result"
}
],
"source": [
"tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n",
"source_data = PileUncopyrightedDataset(tokenizer=tokenizer)"
"src_model = HookedTransformer.from_pretrained(\"solu-1l\", dtype=\"float32\")\n",
"src_d_mlp: int = src_model.cfg.d_mlp # type: ignore\n",
"src_d_mlp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Source Model"
"### Source Dataset"
]
},
{
@@ -97,24 +101,31 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded pretrained model solu-1l into HookedTransformer\n"
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a1ce590449484e1788109c4f13a2e8bf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"2048"
"Resolving data files: 0%| | 0/30 [00:00<?, ?it/s]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
"output_type": "display_data"
}
],
"source": [
"src_model = HookedTransformer.from_pretrained(\"solu-1l\", dtype=\"float32\")\n",
"src_d_mlp = src_model.cfg.d_mlp\n",
"src_d_mlp"
"tokenizer: PreTrainedTokenizerBase = src_model.tokenizer # type: ignore\n",
"source_data = PileUncopyrightedDataset(tokenizer=tokenizer)\n",
"src_dataloader = source_data.get_dataloader(batch_size=8)"
]
},
{
@@ -130,7 +141,7 @@
"metadata": {},
"outputs": [],
"source": [
"max_items = 2_000_000\n",
"max_items = 1_000_000\n",
"store = TensorActivationStore(max_items, src_d_mlp, device)"
]
},
@@ -143,7 +154,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -162,7 +173,7 @@
")"
]
},
"execution_count": 10,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -197,20 +208,99 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "309fbf4a29a147ada581ba09b0cff34d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generate/Train Cycles: 0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a26f99ac95d44bf196f1d5fe70bafbe9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generate Activations: 0%| | 0/1000000 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5cd07ef70a1f4b4c97cd2828f4cfd745",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Train Autoencoder: 0%| | 0/1000000 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/alan/Documents/Repos/sparse_autoencoder/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py:251: UserWarning: The operator 'aten::sgn.out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:13.)\n",
+ " Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "242a91de8f694f64a7d04e93f25b95dc",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generate Activations: 0%| | 0/1000000 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5337eac728eb4ced9590c001e20e53ed",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Train Autoencoder: 0%| | 0/1000000 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"# pipeline(\n",
"# src_model=src_model,\n",
"# src_model_activation_hook_point=\"blocks.0.mlp.hook_post\",\n",
"# src_model_activation_layer=0,\n",
"# src_dataloader=src_dataloader,\n",
"# activation_store=store,\n",
"# num_activations_before_training=max_items,\n",
"# autoencoder=autoencoder,\n",
"# device=device,\n",
"# )"
"pipeline(\n",
" src_model=src_model,\n",
" src_model_activation_hook_point=\"blocks.0.mlp.hook_post\",\n",
" src_model_activation_layer=0,\n",
" src_dataloader=src_dataloader,\n",
" activation_store=store,\n",
" num_activations_before_training=max_items,\n",
" autoencoder=autoencoder,\n",
" device=device,\n",
")"
]
}
],
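Note: pieced together from the hunks above, the repaired demo now runs roughly as the sketch below. The SparseAutoencoder constructor arguments are not part of this diff, so that line is left elided; everything else is taken from the changed cells. The store is preallocated, so at the new max_items of 1,000,000 it buffers about 1,000,000 × 2,048 × 4 bytes ≈ 8 GB of float32 activations, which is presumably why max_items was halved from 2,000,000.

from sparse_autoencoder import TensorActivationStore, SparseAutoencoder, pipeline
from sparse_autoencoder.source_data.pile_uncopyrighted import PileUncopyrightedDataset
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
from transformers import PreTrainedTokenizerBase

device = get_device()

# Source model comes first so its tokenizer can be reused for the dataset.
src_model = HookedTransformer.from_pretrained("solu-1l", dtype="float32")
src_d_mlp: int = src_model.cfg.d_mlp  # type: ignore  # 2048 for solu-1l

# Source dataset, tokenized with the model's own tokenizer.
tokenizer: PreTrainedTokenizerBase = src_model.tokenizer  # type: ignore
source_data = PileUncopyrightedDataset(tokenizer=tokenizer)
src_dataloader = source_data.get_dataloader(batch_size=8)

# Preallocated activation buffer (~8 GB of float32 at these sizes).
max_items = 1_000_000
store = TensorActivationStore(max_items, src_d_mlp, device)

autoencoder = SparseAutoencoder(...)  # constructor args are not shown in this diff

pipeline(
    src_model=src_model,
    src_model_activation_hook_point="blocks.0.mlp.hook_post",
    src_model_activation_layer=0,
    src_dataloader=src_dataloader,
    activation_store=store,
    num_activations_before_training=max_items,
    autoencoder=autoencoder,
    device=device,
)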
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -36,7 +36,7 @@
[tool.poe.tasks.check]
help="Run all checks"
ignore_fail=false
sequence=["checklock", "format", "lint", "test", "typecheck"]
sequence=["check-lock", "format", "lint", "test", "typecheck"]

[tool.poe.tasks.format]
cmd="ruff format sparse_autoencoder"
3 changes: 3 additions & 0 deletions sparse_autoencoder/activation_store/base_store.py
@@ -87,6 +87,9 @@ def __getitem__(self, index: int) -> ActivationStoreItem:
"""Get an Item from the Store."""
raise NotImplementedError

def shuffle(self) -> None:
"""Optional shuffle method."""


class StoreFullError(IndexError):
"""Exception raised when the activation store is full."""
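Note: the new no-op shuffle() on the base class is what lets pipeline.py (below) drop its hasattr() check and call the method unconditionally. A minimal sketch of the pattern, with a hypothetical DemoStore standing in for the real subclasses:

import random


class BaseStore:
    """Stripped-down stand-in for the activation store base class."""

    def shuffle(self) -> None:
        """Optional shuffle method: a no-op unless a subclass overrides it."""


class DemoStore(BaseStore):
    """Hypothetical subclass that actually has something to shuffle."""

    def __init__(self, items: list[int]) -> None:
        self.items = items

    def shuffle(self) -> None:
        random.shuffle(self.items)


def run(store: BaseStore) -> None:
    # Safe on every store -- no hasattr() duck-typing or `# type: ignore` needed.
    store.shuffle()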
10 changes: 5 additions & 5 deletions sparse_autoencoder/source_data/pile_uncopyrighted.py
@@ -1,7 +1,7 @@
"""The Pile Uncopyrighted Dataset."""
from typing import TypedDict, final

- from transformers import PreTrainedTokenizerFast
+ from transformers import PreTrainedTokenizerBase

from sparse_autoencoder.source_data.abstract_dataset import (
    PreprocessTokenizedPrompts,
@@ -26,7 +26,7 @@ class PileUncopyrightedDataset(SourceDataset[PileUncopyrightedSourceDataBatch]):
    https://huggingface.co/datasets/monology/pile-uncopyrighted
    """

-     tokenizer: PreTrainedTokenizerFast
+     tokenizer: PreTrainedTokenizerBase

    def preprocess(
        self,
@@ -65,7 +65,7 @@ def preprocess(

    def __init__(
        self,
-         tokenizer: PreTrainedTokenizerFast,
+         tokenizer: PreTrainedTokenizerBase,
        context_size: int = 250,
        buffer_size: int = 1000,
        preprocess_batch_size: int = 1000,
@@ -75,8 +75,8 @@ def __init__(
"""Initialize the Pile Uncopyrighted dataset.
Example:
>>> from transformers import PreTrainedTokenizerFast
>>> tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
>>> from transformers import GPT2TokenizerFast
>>> tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
>>> data = PileUncopyrightedDataset(
... tokenizer=tokenizer
... )
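Note: PreTrainedTokenizerBase is the shared parent of both the slow and fast Hugging Face tokenizer classes, so widening the annotation lets the dataset accept either a standalone tokenizer or the one already attached to the source model. A hedged sketch (downloads the model and tokenizer when run):

from transformer_lens import HookedTransformer
from transformers import GPT2TokenizerFast, PreTrainedTokenizerBase

from sparse_autoencoder.source_data.pile_uncopyrighted import PileUncopyrightedDataset

# Either a standalone fast tokenizer...
standalone = GPT2TokenizerFast.from_pretrained("gpt2")
# ...or the tokenizer already attached to the source model...
attached = HookedTransformer.from_pretrained("solu-1l").tokenizer

# ...both satisfy the widened annotation.
for tok in (standalone, attached):
    assert isinstance(tok, PreTrainedTokenizerBase)
    data = PileUncopyrightedDataset(tokenizer=tok)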
19 changes: 19 additions & 0 deletions sparse_autoencoder/source_data/tests/test_pile_uncopyrighted.py
@@ -21,3 +21,22 @@ def test_tokenized_prompts_correct_size(context_size: int) -> None:
    # Check the tokens are integers
    for token in item["input_ids"]:
        assert isinstance(token, int)
+
+
+ def test_dataloader_correct_size_items() -> None:
+     """Test the dataloader returns the correct number & sized items."""
+     batch_size = 10
+     context_size = 250
+     tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
+     data = PileUncopyrightedDataset(tokenizer=tokenizer, context_size=context_size)
+     dataloader = data.get_dataloader(batch_size=batch_size)
+
+     checks = 100
+     for item in dataloader:
+         checks -= 1
+         if checks == 0:
+             break
+
+         tokens = item["input_ids"]
+         assert tokens.shape[0] == batch_size
+         assert tokens.shape[1] == context_size
6 changes: 3 additions & 3 deletions sparse_autoencoder/train/generate_activations.py
@@ -62,7 +62,7 @@ def generate_activations(
    model.add_hook(cache_name, hook)

    # Get the input dimensions for logging
-     first_item: Int[Tensor, "batch pos"] = next(iter(dataloader))
+     first_item: Int[Tensor, "batch pos"] = next(iter(dataloader))["input_ids"]
    batch_size: int = first_item.shape[0]
    context_size: int = first_item.shape[1]
    activations_per_batch: int = context_size * batch_size
@@ -78,9 +78,9 @@
        leave=False,
        dynamic_ncols=True,
    ) as progress_bar:
-         for input_ids in dataloader:
+         for batch in dataloader:
            try:
-                 input_ids = input_ids.to(device)  # noqa: PLW2901
+                 input_ids = batch["input_ids"].to(device)
                model.forward(input_ids, stop_at_layer=layer + 1)  # type: ignore
                progress_bar.update(activations_per_batch)

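Note: the underlying fix here is that the dataloader yields Hugging Face-style dict batches rather than bare tensors, so each batch must be indexed with "input_ids" before moving it to the device. A minimal sketch of the corrected consumption loop (batch size and layer are illustrative):

from transformer_lens import HookedTransformer

from sparse_autoencoder.source_data.pile_uncopyrighted import PileUncopyrightedDataset

model = HookedTransformer.from_pretrained("solu-1l")
data = PileUncopyrightedDataset(tokenizer=model.tokenizer)  # type: ignore

layer = 0
for batch in data.get_dataloader(batch_size=8):
    # Each batch is a dict, not a tensor: pull out the token ids first.
    input_ids = batch["input_ids"]
    # Run only up to the hooked layer; hooks capture the activations.
    model.forward(input_ids, stop_at_layer=layer + 1)  # type: ignore
    break  # one batch is enough for the illustration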
3 changes: 1 addition & 2 deletions sparse_autoencoder/train/pipeline.py
@@ -78,8 +78,7 @@ def pipeline(

    # Shuffle the store if it has a shuffle method - it is often more efficient to
    # create a shuffle method ourselves rather than get the DataLoader to shuffle
-     if hasattr(activation_store, "shuffle"):
-         activation_store.shuffle()  # type: ignore
+     activation_store.shuffle()

    # Create a dataloader from the store
    dataloader = DataLoader(
