Skip to content

Commit

Permalink
fix: limit quantile to max 16 million elements
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Jan 17, 2025
1 parent 4dbde12 commit f193517
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion sae_dashboard/utils_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ def decode(token_id: int | list[int]) -> str | list[str]:
**HTML_ANOMALIES_REVERSED,
}

# torch.quantile raises an error on tensors larger than 2**24 (~16.7M)
# elements; keep inputs strictly below that ceiling.
# Upstream report: https://github.com/pytorch/pytorch/issues/64947
MAX_QUANTILE = (1 << 24) - 1


def process_str_tok(str_tok: str, html: bool = True) -> str:
"""
Expand Down Expand Up @@ -617,7 +621,7 @@ def create(
)

batch_quantile_data = torch.quantile(
batch.to(torch.float32),
batch[:, :MAX_QUANTILE].to(torch.float32),
quantiles_tensor.to(torch.float32),
dim=-1,
)
Expand Down

0 comments on commit f193517

Please sign in to comment.