Skip to content

Commit

Permalink
Add option to specify chunks in broadcast_to_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
bouweandela committed Jan 16, 2024
1 parent fa07442 commit dcf9df8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/src/whatsnew/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ This document explains the changes made to Iris for this release
lazy data from file. This will also speed up coordinate comparison.
(:pull:`5610`)

#. `@bouweandela`_ added the option to specify the Dask chunks of the target
array in :func:`iris.util.broadcast_to_shape`. (:pull:`5620`)

🔥 Deprecations
===============
Expand Down
27 changes: 27 additions & 0 deletions lib/iris/tests/unit/util/test_broadcast_to_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,33 @@ def test_lazy_masked(self, mocked_compute):
for j in range(4):
self.assertMaskedArrayEqual(b[i, :, j, :].compute().T, m.compute())

@mock.patch.object(dask.base, "compute", wraps=dask.base.compute)
def test_lazy_chunks(self, mocked_compute):
    # The caller may pass target chunks together with the target shape.
    # They only take effect along dimensions that are new in the result or
    # that have size 1 in the source array; elsewhere the source chunks
    # are kept.
    lazy_source = da.ma.masked_array(
        data=[[1, 2, 3, 4, 5]],
        mask=[[0, 1, 0, 0, 0]],
    ).rechunk((1, 2))
    requested_chunks = (
        1,  # honoured: target dimension 0 does not exist in the source
        2,  # honoured: the source dimension mapped here has size 1
        3,  # ignored: broadcasting does not rechunk this dimension
    )
    result = broadcast_to_shape(
        lazy_source,
        shape=(3, 4, 5),
        dim_map=(1, 2),
        chunks=requested_chunks,
    )
    # Building the broadcast result must not trigger any computation.
    mocked_compute.assert_not_called()
    for outer in range(3):
        for inner in range(4):
            broadcast_row = result[outer, inner, :].compute()
            source_row = lazy_source[0].compute()
            self.assertMaskedArrayEqual(broadcast_row, source_row)
    expected_chunks = ((1, 1, 1), (2, 2), (2, 2, 1))
    assert result.chunks == expected_chunks

def test_masked_degenerate(self):
# masked arrays can have degenerate masks too
a = np.random.random([2, 3])
Expand Down
29 changes: 25 additions & 4 deletions lib/iris/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import iris.exceptions


def broadcast_to_shape(array, shape, dim_map):
def broadcast_to_shape(array, shape, dim_map, chunks=None):
"""Broadcast an array to a given shape.
Each dimension of the array must correspond to a dimension in the
Expand All @@ -46,6 +46,13 @@ def broadcast_to_shape(array, shape, dim_map):
the index in *shape* which the dimension of *array* corresponds
to, so the first element of *dim_map* gives the index of *shape*
that corresponds to the first dimension of *array* etc.
chunks : :class:`tuple`, optional
If the source array is a :class:`dask.array.Array` and a value is
provided, then the result will use these chunks instead of the chunks of
the source array. Setting chunks explicitly as part of
broadcast_to_shape is more efficient than rechunking afterwards. The
values provided here will only be used along dimensions that are new in
the result or have size 1 in the source array.
Examples
--------
Expand All @@ -68,27 +75,41 @@ def broadcast_to_shape(array, shape, dim_map):
See more at :doc:`/userguide/real_and_lazy_data`.
"""
if isinstance(array, da.Array):
if chunks is not None:
chunks = list(chunks)
for src_idx, tgt_idx in enumerate(dim_map):
# Only use the specified chunks along new dimensions or on
# dimensions that have size 1 in the source array.
if array.shape[src_idx] != 1:
chunks[tgt_idx] = array.chunks[src_idx]
broadcast = functools.partial(
da.broadcast_to, shape=shape, chunks=chunks
)
else:
broadcast = functools.partial(np.broadcast_to, shape=shape)

n_orig_dims = len(array.shape)
n_new_dims = len(shape) - n_orig_dims
array = array.reshape(array.shape + (1,) * n_new_dims)

# Get dims in required order.
array = np.moveaxis(array, range(n_orig_dims), dim_map)
new_array = np.broadcast_to(array, shape)
new_array = broadcast(array)

if ma.isMA(array):
# broadcast_to strips masks so we need to handle them explicitly.
mask = ma.getmask(array)
if mask is ma.nomask:
new_mask = ma.nomask
else:
new_mask = np.broadcast_to(mask, shape)
new_mask = broadcast(mask)
new_array = ma.array(new_array, mask=new_mask)

elif is_lazy_masked_data(array):
# broadcast_to strips masks so we need to handle them explicitly.
mask = da.ma.getmaskarray(array)
new_mask = da.broadcast_to(mask, shape)
new_mask = broadcast(mask)
new_array = da.ma.masked_array(new_array, new_mask)

return new_array
Expand Down

0 comments on commit dcf9df8

Please sign in to comment.