
Commit

mollify ruff, improve type hints
tadamcz committed Jan 14, 2025
1 parent 10d41ef commit c715bd7
Showing 3 changed files with 46 additions and 37 deletions.
1 change: 0 additions & 1 deletion src/inspect_ai/scorer/__init__.py
@@ -53,7 +53,6 @@
     "bootstrap_std",
     "std",
     "stderr",
-    "clustered_stderr",
     "mean",
     "Metric",
     "metric",
11 changes: 7 additions & 4 deletions src/inspect_ai/scorer/_metrics/std.py
@@ -1,7 +1,8 @@
 from logging import getLogger
-from typing import cast
+from typing import Optional, cast
 
 import numpy as np
+from numpy._typing import NDArray
 
 from .._metric import (
     Metric,
@@ -101,8 +102,10 @@ def metric(scores: list[ReducedScore]) -> float:
 
 
 def hierarchical_bootstrap(
-    scores: list[list[float]], num_samples: int = 1000, random_state: int = None
-) -> np.ndarray:
+    scores: list[list[float]],
+    num_samples: int = 1000,
+    random_state: Optional[int] = None,
+) -> NDArray[np.float64]:
     """Efficient implementation of hierarchical bootstrap using vectorized operations.
     See tests for a more readable counterpart using loops, ``readable_hierarchical_bootstrap``.
@@ -142,7 +145,7 @@ def hierarchical_bootstrap(
     # Then average over clusters
     bootstrap_means = np.mean(cluster_means, axis=1)  # Shape: (num_samples,)
 
-    return bootstrap_means
+    return cast(NDArray[np.float64], bootstrap_means)
 
 
 @metric
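Not part of the commit: a minimal usage sketch of the updated hierarchical_bootstrap signature, for context. The cluster values below are made up, and summarising the returned bootstrap means with np.std to obtain a clustered standard error is an assumption about how the stderr() metric consumes this helper, not something this hunk shows.

import numpy as np

from inspect_ai.scorer._metrics.std import hierarchical_bootstrap

# Hypothetical per-cluster scores (equal-sized clusters, as in the tests below).
scores = [
    [0.3, 0.8, 0.4],  # cluster 1
    [0.6, 0.7, 0.5],  # cluster 2
]

# Returns a typed NDArray[np.float64] of bootstrap means, one per resample.
means = hierarchical_bootstrap(scores, num_samples=1000, random_state=42)

# Assumed summary step: take the spread of the bootstrap distribution of means.
clustered_stderr = float(np.std(means))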
71 changes: 39 additions & 32 deletions tests/scorer/test_metric_stderr.py
@@ -1,9 +1,12 @@
-from inspect_ai.scorer._metric import ReducedScore, Score
-from inspect_ai.scorer._metrics.std import stderr, hierarchical_bootstrap
+import contextlib
+from typing import List, Optional
 
 import numpy as np
-from typing import List
 import pytest
-import contextlib
+
+from inspect_ai.scorer._metric import ReducedScore, Score
+from inspect_ai.scorer._metrics.std import hierarchical_bootstrap, stderr
+
+
 @contextlib.contextmanager
 def fixed_seed(seed=42):
@@ -38,6 +41,7 @@ def test_stderr_variance_hierarchy():
     # Stderr should be larger when there's within-cluster variance
     assert stderr_varying > stderr_constant
 
+
 def test_stderr_constant():
     """When all values are identical, stderr should be 0"""
     with fixed_seed():
@@ -50,6 +54,7 @@ def test_stderr_constant():
         result = metric(scores)
         assert result == pytest.approx(0)
 
+
 def test_stderr_trivial_clusters():
     """Test with multiple clusters but only 1 member each"""
     with fixed_seed():
@@ -61,23 +66,24 @@ def test_stderr_trivial_clusters():
         metric = stderr()
         result = metric(scores)
         # For [0.3, 0.8, 0.4]: std ≈ 0.216, std/√3 ≈ 0.124
-        assert result == pytest.approx(0.124, rel=5/100)
+        assert result == pytest.approx(0.124, rel=5 / 100)
 
+
 def test_stderr_single_cluster():
     """Test with only one cluster but multiple members"""
     with fixed_seed():
         scores = [
-            ReducedScore(value=0.5, children=[
-                Score(value=0.3),
-                Score(value=0.8),
-                Score(value=0.4)
-            ])
+            ReducedScore(
+                value=0.5,
+                children=[Score(value=0.3), Score(value=0.8), Score(value=0.4)],
+            )
         ]
         metric = stderr()
         result = metric(scores)
         # Same as above but with different hierarchy structure
         # For [0.3, 0.8, 0.4]: std ≈ 0.216, std/√3 ≈ 0.124
-        assert result == pytest.approx(0.124, rel=5/100)
+        assert result == pytest.approx(0.124, rel=5 / 100)
 
+
 def test_stderr_empty_cluster_raises():
     """Test that empty clusters raise ValueError"""
@@ -88,26 +94,29 @@ def test_stderr_empty_cluster_raises():
     with pytest.raises(ValueError, match="requires non-empty clusters"):
         metric(scores)
 
+
 def readable_hierarchical_bootstrap(
-        scores: List[List[float]],
-        n_bootstrap: int = 1000,
-        random_state: int = None
+    scores: List[List[float]],
+    n_bootstrap: int = 1000,
+    random_state: Optional[int] = None,
 ) -> List[float]:
-    """
-    Readable implementation of hierarchical bootstrap using Python loops for clarity.
-    """
+    """Readable implementation of hierarchical bootstrap using Python loops for clarity."""
     rng = np.random.default_rng(random_state)
     scores_array = np.array(scores)  # Shape: (n_clusters, n_members)
     bootstrap_means = []
 
     for _ in range(n_bootstrap):
         # Resample clusters using axis parameter
-        sampled_clusters = rng.choice(scores_array, size=len(scores_array), replace=True, axis=0)
+        sampled_clusters = rng.choice(
+            scores_array, size=len(scores_array), replace=True, axis=0
+        )
 
         # For each cluster, resample its members
         cluster_means = []
         for cluster_scores in sampled_clusters:
-            resampled_members = rng.choice(cluster_scores, size=len(cluster_scores), replace=True)
+            resampled_members = rng.choice(
+                cluster_scores, size=len(cluster_scores), replace=True
+            )
             cluster_means.append(np.mean(resampled_members))
 
         # Calculate mean across clusters for this bootstrap sample
@@ -119,10 +128,9 @@ def readable_hierarchical_bootstrap(
 
 
 def test_redable_matches_fast():
-    """Test that the readable implementation of hierarchical bootstrap matches the fast
-    implementation."""
+    """Test that the readable implementation of hierarchical bootstrap matches the fast implementation."""
     scores = [
-        [2.0, 4.0, 8.0, 16.],  # cluster 1
+        [2.0, 4.0, 8.0, 16.0],  # cluster 1
         [0.5, 0.7, 0.6, 0.8],  # cluster 2
         [2.1, 2.3, 2.0, 2.2],  # cluster 3
     ]
@@ -134,25 +142,24 @@ def test_redable_matches_fast():
 
     # Run both implementations
     fast_results = hierarchical_bootstrap(
-        scores,
-        num_samples=n_samples,
-        random_state=random_seed
+        scores, num_samples=n_samples, random_state=random_seed
    )
 
     readable_results = readable_hierarchical_bootstrap(
-        scores,
-        n_bootstrap=n_samples,
-        random_state=random_seed
+        scores, n_bootstrap=n_samples, random_state=random_seed
    )
 
-
     # Compare mean and standard deviation
-    assert np.mean(fast_results) == pytest.approx(np.mean(readable_results), rel=relative_tolerance)
-    assert np.std(fast_results) == pytest.approx(np.std(readable_results), rel=relative_tolerance)
+    assert np.mean(fast_results) == pytest.approx(
+        np.mean(readable_results), rel=relative_tolerance
+    )
+    assert np.std(fast_results) == pytest.approx(
+        np.std(readable_results), rel=relative_tolerance
+    )
 
     # Compare quartiles
     q = [0.25, 0.5, 0.75]
     fast_quantiles = np.quantile(fast_results, q)
     readable_quantiles = np.quantile(readable_results, q)
     for fq, rq in zip(fast_quantiles, readable_quantiles):
-        assert fq == pytest.approx(rq, rel=relative_tolerance)
+        assert fq == pytest.approx(rq, rel=relative_tolerance)
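A quick numeric check of the hard-coded expectation reused in test_stderr_trivial_clusters and test_stderr_single_cluster (not part of the diff; it assumes the population standard deviation, ddof=0, which is what reproduces the quoted figures):

import numpy as np

values = np.array([0.3, 0.8, 0.4])
std = values.std()  # population std (ddof=0): sqrt(0.14 / 3) ≈ 0.2160
print(std, std / np.sqrt(3))  # ≈ 0.2160 and ≈ 0.1247, matching the 0.124 expectation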
