Skip to content

Commit

Permalink
fix: Merge pull request #33 from jbloomAus/fix/topk-selection-purview
Browse files (browse the repository at this point in the history)
Fix/topk selection purview
  • Loading branch information
jbloomAus authored Oct 24, 2024
2 parents 8235a9e + fb141ae commit afccd5a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
16 changes: 13 additions & 3 deletions sae_dashboard/feature_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jaxtyping import Float, Int
from sae_lens import SAE
from sae_lens.config import DTYPE_MAP as DTYPES
from sae_lens.sae import TopK
from torch import Tensor, nn
from tqdm.auto import tqdm

Expand Down Expand Up @@ -87,11 +88,20 @@ def get_feature_data( # type: ignore
self.encoder.device
) # make sure acts are on the correct device

# Compute feature activations from this
with FeatureMaskingContext(self.encoder, feature_indices):
feature_acts = self.encoder.encode(primary_acts).to(
# For TopK, compute all activations first, then select features
if isinstance(self.encoder.activation_fn, TopK):
# Get all features' activations
all_features_acts = self.encoder.encode(primary_acts)
# Then select only the features we're interested in
feature_acts = all_features_acts[:, :, feature_indices].to(
DTYPES[self.cfg.dtype]
)
else:
# For other activation functions, use the masking context
with FeatureMaskingContext(self.encoder, feature_indices):
feature_acts = self.encoder.encode(primary_acts).to(
DTYPES[self.cfg.dtype]
)

self.update_rolling_coefficients(
model_acts=primary_acts,
Expand Down
4 changes: 2 additions & 2 deletions sae_dashboard/neuronpedia/neuronpedia_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

import numpy as np
import torch
import wandb
import wandb.sdk
from matplotlib import colors
from sae_lens.sae import SAE
from sae_lens.toolkit.pretrained_saes import load_sparsity
from sae_lens.training.activations_store import ActivationsStore
from tqdm import tqdm
from transformer_lens import HookedTransformer

import wandb
import wandb.sdk
from sae_dashboard.components_config import (
ActsHistogramConfig,
Column,
Expand Down
2 changes: 1 addition & 1 deletion scripts/run_dashboards_runpod.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def copy_output_folder(source: str, destination: str):

def find_saelens_release_from_neuronpedia_id(neuronpedia_id: str) -> str:
for release, item in directory.items():
for _, np_id in item.neuronpedia_id.items():
for _, np_id in item.neuronpedia_id.items(): # type: ignore
if np_id == neuronpedia_id:
return release
raise ValueError(f"Neuronpedia ID {neuronpedia_id} not found")
Expand Down

0 comments on commit afccd5a

Please sign in to comment.