Merge pull request #19 from jbloomAus/fix/remove_precision_reduction
Removed precision reduction option
jbloomAus authored Aug 15, 2024
2 parents 3565762 + 1b6a4a9 commit a5f8df1
Showing 2 changed files with 14 additions and 195 deletions.
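For context (an illustrative sketch, not part of the commit): after this change, FeatureStatistics.create no longer accepts a reduce_precision flag, and quantiles are always computed by upcasting to float32 and calling torch.quantile. Assuming the import path matches the file shown below, usage looks roughly like:

import torch
from sae_dashboard.utils_fns import FeatureStatistics

# Hypothetical shapes: 16 features, 512 activation samples each.
acts = torch.randn(16, 512, dtype=torch.float16)
stats = FeatureStatistics.create(acts, batch_size=8)  # no reduce_precision kwarg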
sae_dashboard/utils_fns.py: 101 changes (14 additions & 87 deletions)
@@ -467,77 +467,6 @@ def pad_with_zeros(
]


-def float16_quantile(
-    input: torch.Tensor,
-    q: torch.Tensor,
-    dim: Optional[int] = None,
-    keepdim: bool = False,
-    interpolation: str = "linear",
-) -> torch.Tensor:
-    """Performs the torch quantile function for float16 tensors.
-    Args:
-        input (torch.Tensor): The input tensor.
-        q (torch.Tensor): The quantile(s) to compute, which must be between 0 and 1.
-        dim: The dimension(s) to reduce.
-        keepdim: Whether to keep the same as the original.
-        interpolation: The interpolation method to use when the desired quantile lies between two data points i and j.
-    Returns:
-        torch.Tensor: The computed quantile(s).
-    """
-    print("Using float16 quantile calculation")
-    if dim is None:
-        input = input.flatten()
-        dim = 0
-
-    # Ensure q is a 1D tensor
-    q = q.squeeze()
-
-    # Move dim to the end for easier processing
-    input = input.transpose(dim, -1)
-
-    sorted_input, _ = torch.sort(input, dim=-1)
-
-    quantile_indices = (
-        (q * (input.shape[-1] - 1))
-        .to(input.dtype)
-        .unsqueeze(0)
-        .expand(input.shape[:-1] + (-1,))
-    )
-    lower_indices = torch.floor(quantile_indices).long()
-    upper_indices = torch.ceil(quantile_indices).long()
-    fractional_part = quantile_indices - lower_indices.to(input.dtype)
-
-    if interpolation == "linear":
-        lower_values = torch.gather(sorted_input, -1, lower_indices)
-        upper_values = torch.gather(sorted_input, -1, upper_indices)
-        result = lower_values * (1 - fractional_part) + upper_values * fractional_part
-    elif interpolation == "lower":
-        result = torch.gather(sorted_input, -1, lower_indices)
-    elif interpolation == "higher":
-        result = torch.gather(sorted_input, -1, upper_indices)
-    elif interpolation == "nearest":
-        nearest_indices = torch.where(
-            fractional_part < 0.5, lower_indices, upper_indices
-        )
-        result = torch.gather(sorted_input, -1, nearest_indices)
-    elif interpolation == "midpoint":
-        lower_values = torch.gather(sorted_input, -1, lower_indices)
-        upper_values = torch.gather(sorted_input, -1, upper_indices)
-        result = (lower_values + upper_values) / 2
-    else:
-        raise ValueError(f"Invalid interpolation method: {interpolation}")
-
-    # Move dim back to its original position
-    result = result.transpose(0, dim)
-
-    if not keepdim:
-        result = result.squeeze(dim)
-
-    return result


@dataclass_json
@dataclass
class FeatureStatistics:
@@ -583,10 +512,18 @@ def create(
        ranges_and_precisions: list[
            tuple[list[float], int]
        ] = ASYMMETRIC_RANGES_AND_PRECISIONS,
-        reduce_precision: bool = True,
        batch_size: Optional[int] = None,
    ) -> "FeatureStatistics":
+        """Calculates various statistics for a tensor of activations.
+        Args:
+            data: A tensor of activations; should be shape (n_features, n_samples (n_prompts * n_prompt_tokens)).
+            ranges_and_precisions: A list of tuples of the form (range, precision).
+            batch_size: The feature batch size to use for processing the acts. Reduce this if you encounter OOM errors.
+        Returns:
+            A FeatureStatistics object.
+        """
        if not batch_size:
            batch_size = 0 if data is None else data.shape[0]

@@ -625,21 +562,11 @@ def create(
                quantiles, dtype=batch.dtype, device=batch.device
            )

-            if batch.dtype in [torch.float16, torch.bfloat16]:
-                batch_quantile_data = float16_quantile(batch, quantiles_tensor, dim=-1)
-            else:
-                if reduce_precision:
-                    batch_quantile_data = float16_quantile(
-                        batch.to(torch.float16),
-                        quantiles_tensor.to(torch.float16),
-                        dim=-1,
-                    )
-                else:
-                    batch_quantile_data = torch.quantile(
-                        batch.to(torch.float32),
-                        quantiles_tensor.to(torch.float32),
-                        dim=-1,
-                    )
+            batch_quantile_data = torch.quantile(
+                batch.to(torch.float32),
+                quantiles_tensor.to(torch.float32),
+                dim=-1,
+            )

            quantile_data.extend(batch_quantile_data.T.tolist())

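For intuition, here is an illustrative sketch (not from the repository) of why the two code paths agreed: the removed float16_quantile helper reimplemented linear-interpolation quantiles via sort, floor/ceil indexing, and gather, which is the same thing the retained torch.quantile call computes in float32.

import torch

torch.manual_seed(0)
x = torch.rand(4, 100, dtype=torch.float16)
q = torch.tensor([0.25, 0.5, 0.75])

# The float32 path kept by this PR; output shape is (n_quantiles, n_features).
expected = torch.quantile(x.to(torch.float32), q, dim=-1)  # (3, 4)

# Hand-rolled linear interpolation, mirroring the removed helper's scheme.
sorted_x, _ = torch.sort(x.to(torch.float32), dim=-1)
idx = q * (x.shape[-1] - 1)
lo, hi = idx.floor().long(), idx.ceil().long()
frac = idx - lo
manual = sorted_x[:, lo] * (1 - frac) + sorted_x[:, hi] * frac  # (4, 3)
assert torch.allclose(manual.T, expected, atol=1e-6)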
tests/unit/test_util_fns.py: 108 changes (0 additions & 108 deletions)
@@ -256,114 +256,6 @@ def test_feature_statistics_update():
assert len(stats1.quantile_data) == 4


-def test_feature_statistics_quantile_accuracy():
-    # Create sample data
-    torch.manual_seed(0)  # for reproducibility
-    data = torch.rand(1000, 100)  # 1000 features, 100 data points each
-
-    # Test for float32 and float16
-    for dtype in [torch.float32, torch.float16]:
-        data_typed = data.to(dtype)
-
-        # Create FeatureStatistics object
-        feature_stats = FeatureStatistics.create(data_typed)
-
-        # Calculate quantiles using the same method as in FeatureStatistics.create
-        quantiles = []
-        for r, p in ASYMMETRIC_RANGES_AND_PRECISIONS:
-            start, end = r
-            step = 10**-p
-            quantiles.extend(np.arange(start, end - 0.5 * step, step))
-
-        quantiles_tensor = torch.tensor(quantiles, dtype=data_typed.dtype).to(
-            data.device
-        )
-        expected_quantile_data = torch.quantile(
-            data.to(torch.float32), quantiles_tensor.to(torch.float32), dim=-1
-        )
-        expected_quantile_data = expected_quantile_data.T.tolist()
-        expected_quantile_data = [
-            [round(q, 6) for q in qd] for qd in expected_quantile_data
-        ]
-        for i, qd in enumerate(expected_quantile_data):
-            first_nonzero = next(
-                (i for i, x in enumerate(qd) if abs(x) > 1e-6), len(qd)
-            )
-            expected_quantile_data[i] = qd[first_nonzero:]
-
-        # Compare results
-        for i, (expected, actual) in enumerate(
-            zip(expected_quantile_data, feature_stats.quantile_data)
-        ):
-            print(f"Dtype: {dtype}, Feature {i}")
-            print(f"Expected: {expected[-5:]}...")
-            print(f"Actual: {actual[-5:]}...")
-            print(f"Expected length: {len(expected)}, Actual length: {len(actual)}")
-
-            assert len(expected) == len(
-                actual
-            ), f"Length mismatch for feature {i}, expected {len(expected)}, got {len(actual)}"
-            np.testing.assert_allclose(
-                actual,
-                expected,
-                rtol=1e-2,
-                atol=1e-2,
-                err_msg=f"Mismatch for feature {i} with dtype {dtype}",
-            )
-
-        print(f"All quantiles match for dtype {dtype}")
-
-
-@pytest.mark.parametrize("n_features,n_tokens", [(100, 1000), (50, 500)])
-def test_feature_statistics_precision_reduction(n_features: int, n_tokens: int):
-    # Create a random 2D float32 tensor
-    torch.manual_seed(42)  # for reproducibility
-    data = torch.randn(n_features, n_tokens, dtype=torch.float32)
-
-    # Create FeatureStatistics without reducing precision
-    stats_full = FeatureStatistics.create(data, reduce_precision=False)
-
-    # Create FeatureStatistics with reduced precision
-    stats_reduced = FeatureStatistics.create(data, reduce_precision=True)
-
-    # Compare max values
-    assert np.allclose(
-        stats_full.max, stats_reduced.max, atol=1e-2
-    ), "Max values do not match within tolerance"
-
-    # Compare fraction of non-zero values
-    assert np.allclose(
-        stats_full.frac_nonzero, stats_reduced.frac_nonzero, atol=1e-2
-    ), "Fraction of non-zero values do not match within tolerance"
-
-    # Compare quantiles
-    assert stats_full.quantiles == stats_reduced.quantiles, "Quantiles do not match"
-
-    # Compare quantile data
-    assert len(stats_full.quantile_data) == len(
-        stats_reduced.quantile_data
-    ), "Quantile data lengths do not match"
-    for full_qd, reduced_qd in zip(
-        stats_full.quantile_data, stats_reduced.quantile_data
-    ):
-        assert len(full_qd) == len(reduced_qd), "Quantile data sub-lengths do not match"
-        if not np.allclose(
-            full_qd, reduced_qd, rtol=1e-1
-        ):  # , "Quantile data values do not match within tolerance"
-            print(
-                f"Mean difference: {np.mean(np.abs(np.array(full_qd) - np.array(reduced_qd)))}"
-            )
-
-    # Compare ranges_and_precisions
-    assert (
-        stats_full.ranges_and_precisions == stats_reduced.ranges_and_precisions
-    ), "Ranges and precisions do not match"
-
-    print(f"Test completed for n_features: {n_features}, n_tokens: {n_tokens}")
-    print("Full precision and reduced precision results match within tolerance.")


# def test_feature_statistics_benchmark(large_precision_data):
# # Check if CUDA is available
# if not torch.cuda.is_available():
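One likely motivation for the original helper, shown here as a hedged sketch rather than anything stated in the diff: torch.quantile has historically rejected half-precision inputs, so quantiles over float16 activations required either a custom implementation or an upcast; this PR settles on the upcast.

import torch

x = torch.randn(8, 100, dtype=torch.float16)
q = torch.tensor([0.5])

# In the PyTorch versions we are aware of, torch.quantile only accepts
# float32/float64 inputs, so the direct call fails on half precision:
# torch.quantile(x, q, dim=-1)  # RuntimeError on float16 input
out = torch.quantile(x.to(torch.float32), q, dim=-1)
print(out.shape)  # torch.Size([1, 8])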
