Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handling StopIteration error while generate activations #200

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sparse_autoencoder/train/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
"""Default pipeline."""
from collections.abc import Iterator
from functools import partial

import itertools
import logging
from pathlib import Path
from tempfile import gettempdir
from typing import TYPE_CHECKING, final

from jaxtyping import Float, Int
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger
from pydantic import NonNegativeInt, PositiveInt, validate_call
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
import wandb

from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from sparse_autoencoder.autoencoder.lightning import LitSparseAutoencoder
from sparse_autoencoder.metrics.validate.reconstruction_score import ReconstructionScoreMetric
from sparse_autoencoder.metrics.wrappers.classwise import ClasswiseWrapperWithMean
from sparse_autoencoder.source_data.abstract_dataset import SourceDataset, TorchTokenizedPrompts
from sparse_autoencoder.source_model.replace_activations_hook import replace_activations_hook
from sparse_autoencoder.source_model.store_activations_hook import store_activations_hook
from sparse_autoencoder.source_model.zero_ablate_hook import zero_ablate_hook
from sparse_autoencoder.train.utils.get_model_device import get_model_device
from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


if TYPE_CHECKING:

Check failure on line 34 in sparse_autoencoder/train/pipeline.py

View workflow job for this annotation

GitHub Actions / Checks (3.11)

Ruff (I001)

sparse_autoencoder/train/pipeline.py:2:1: I001 Import block is un-sorted or un-formatted
from sparse_autoencoder.tensor_types import Axis


Expand Down Expand Up @@ -139,7 +141,7 @@
source_dataloader = source_dataset.get_dataloader(
source_data_batch_size, num_workers=num_workers_data_loading
)
self.source_data = iter(source_dataloader)
self.source_data = itertools.cycle(source_dataloader)

@validate_call
def generate_activations(self, store_size: PositiveInt) -> TensorActivationStore:
Expand Down
Loading