Skip to content

Commit

Permalink
fix: multi-gpu-tlens
Browse files Browse the repository at this point in the history
fix: handle multiple tlens devices
  • Loading branch information
jbloomAus authored Aug 27, 2024
2 parents 368b35f + ba5368f commit ed1e967
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions sae_dashboard/feature_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def get_feature_data( # type: ignore
model_activation_dict = self.get_model_acts(i, minibatch)
primary_acts = model_activation_dict[
self.model.activation_config.primary_hook_point
]
].to(
self.encoder.device
) # make sure acts are on the correct device

# Compute feature activations from this
with FeatureMaskingContext(self.encoder, feature_indices):
Expand Down Expand Up @@ -150,11 +152,13 @@ def get_model_acts(
activation_dict = load_tensor_dict_torch(cache_path, self.cfg.device)
else:
activation_dict = self.model.forward(
minibatch_tokens, return_logits=False
minibatch_tokens.to("cpu"), return_logits=False
)
save_tensor_dict_torch(activation_dict, cache_path)
else:
activation_dict = self.model.forward(minibatch_tokens, return_logits=False)
activation_dict = self.model.forward(
minibatch_tokens.to("cpu"), return_logits=False
)

return activation_dict

Expand Down
2 changes: 1 addition & 1 deletion sae_dashboard/sae_vis_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def run(
dtype=tokens.dtype,
device=tokens.device,
),
)
).to(feat_acts.device)
nonzero_feat_acts = feat_acts[(feat_acts > 0) & (ignore_tokens_mask)]
frac_nonzero = nonzero_feat_acts.numel() / feat_acts.numel()
feature_data_dict[feat].acts_histogram_data = (
Expand Down

0 comments on commit ed1e967

Please sign in to comment.