Skip to content

Commit

Permalink
test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Jan 8, 2025
1 parent ab75b77 commit 99cbac5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
13 changes: 5 additions & 8 deletions sup3r/utilities/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ def get_level_masks(cls, lev_array, level):
Parameters
----------
var_array : Union[np.ndarray, da.core.Array]
Array of variable data, for example u-wind in a 4D array of shape
(lat, lon, time, level)
lev_array : Union[np.ndarray, da.core.Array]
Height or pressure values for the corresponding entries in
var_array, in the same shape as var_array. If this is height and
Expand Down Expand Up @@ -51,7 +48,7 @@ def get_level_masks(cls, lev_array, level):
da.arange(lev_array.shape[-1]), lev_array.shape
)
mask1 = lev_indices == argmin1
lev_diff = da.where(mask1, np.inf, lev_diff)
lev_diff = da.abs(da.ma.masked_array(lev_array, mask1) - level)
argmin2 = da.argmin(lev_diff, axis=-1, keepdims=True)
mask2 = lev_indices == argmin2
return mask1, mask2
Expand All @@ -61,7 +58,7 @@ def _lin_interp(cls, lev_samps, var_samps, level):
"""Linearly interpolate between levels."""
diff = da.map_blocks(lambda x, y: x - y, lev_samps[1], lev_samps[0])
alpha = da.where(
diff < 1e-3,
diff == 0,
0,
da.map_blocks(lambda x, y: x / y, (level - lev_samps[0]), diff),
)
Expand Down Expand Up @@ -109,16 +106,16 @@ def interp_to_level(
Parameters
----------
var_array : xr.DataArray
Array of variable data, for example u-wind in a 4D array of shape
(lat, lon, time, level)
lev_array : xr.DataArray
Height or pressure values for the corresponding entries in
var_array, in the same shape as var_array. If this is height and
the requested levels are hub heights above surface, lev_array
should be the geopotential height corresponding to every var_array
index relative to the surface elevation (subtract the elevation at
the surface from the geopotential height)
var_array : xr.DataArray
Array of variable data, for example u-wind in a 4D array of shape
(lat, lon, time, level)
level : float
level or levels to interpolate to (e.g. final desired hub height
above surface elevation)
Expand Down
5 changes: 4 additions & 1 deletion tests/data_handlers/test_dh_nc_cc.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def test_reload_cache():
target=target,
shape=(20, 20),
cache_kwargs=cache_kwargs,
interp_kwargs={'include_single_levels': True}
)

# reload from cache
Expand All @@ -80,7 +81,9 @@ def test_reload_cache():
cache_kwargs=cache_kwargs,
)
assert all(f in cached for f in features)
assert np.array_equal(handler.as_array(), cached.as_array())
harr = handler.as_array().compute()
carr = cached.as_array().compute()
assert np.array_equal(harr, carr)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 99cbac5

Please sign in to comment.