Transition from DeviceArray to jax.Array #410

Merged: 18 commits, Apr 21, 2023
46 changes: 22 additions & 24 deletions docs/source/include/blockarray.rst
@@ -13,12 +13,12 @@ BlockArray

The class :class:`.BlockArray` provides a way to combine arrays of
different shapes into a single object for use with other SCICO classes.
A :class:`.BlockArray` consists of a list of :obj:`~jax.numpy.DeviceArray`
objects, which we refer to as blocks. A :class:`.BlockArray` differs from
a list in that, whenever possible, :class:`.BlockArray` properties and
methods (including unary and binary operators like +, -, \*, ...)
automatically map along the blocks, returning another :class:`.BlockArray`
or tuple as appropriate. For example,
A :class:`.BlockArray` consists of a list of :class:`jax.Array` objects,
which we refer to as blocks. A :class:`.BlockArray` differs from a list in
that, whenever possible, :class:`.BlockArray` properties and methods
(including unary and binary operators like +, -, \*, ...) automatically
map along the blocks, returning another :class:`.BlockArray` or tuple as
appropriate. For example,

::

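# A hypothetical sketch, assuming snp is scico.numpy: properties and
# operators of a BlockArray map over its blocks.
>>> import scico.numpy as snp
>>> x = snp.blockarray([snp.ones((2, 3)), snp.zeros(4)])
>>> y = 2.0 * x + 1.0  # scalar operations are applied to each block
>>> len(y)             # number of blocks
2
>>> y[0].shape         # shape of the first block
(2, 3)
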
@@ -108,22 +108,20 @@ Note that:
Motivating Example
------------------

Consider a two-dimensional array :math:`\mb{x} \in \mbb{R}^{n \times m}`.
The discrete differences of a two-dimensional array, :math:`\mb{x} \in
\mbb{R}^{n \times m}`, in the horizontal and vertical directions can
be represented by the arrays :math:`\mb{x}_h \in \mbb{R}^{n \times
(m-1)}` and :math:`\mb{x}_v \in \mbb{R}^{(n-1) \times m}`
respectively. While it is usually useful to consider the output of a
difference operator as a single entity, we cannot combine these two
arrays into a single array since they have different shapes. We could
vectorize each array and concatenate the resulting vectors, leading to
:math:`\mb{\bar{x}} \in \mbb{R}^{n(m-1) + m(n-1)}`, which can be
stored as a one-dimensional array, but this makes it hard to access
the individual components :math:`\mb{x}_h` and :math:`\mb{x}_v`.

We compute the discrete differences of :math:`\mb{x}` in the horizontal
and vertical directions, generating two new arrays: :math:`\mb{x}_h \in
\mbb{R}^{n \times (m-1)}` and :math:`\mb{x}_v \in \mbb{R}^{(n-1)
\times m}`.

As these arrays are of different shapes, we cannot combine them into a
single :class:`~numpy.ndarray`. Instead, we might vectorize each array and concatenate
the resulting vectors, leading to :math:`\mb{\bar{x}} \in
\mbb{R}^{n(m-1) + m(n-1)}`, which can be stored as a one-dimensional
:class:`~numpy.ndarray`. Unfortunately, this makes it hard to access the individual
components :math:`\mb{x}_h` and :math:`\mb{x}_v`.

Instead, we can form a :class:`.BlockArray`: :math:`\mb{x}_B =
[\mb{x}_h, \mb{x}_v]`
Instead, we can construct a :class:`.BlockArray`, :math:`\mb{x}_B =
[\mb{x}_h, \mb{x}_v]`:


::
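# A hypothetical sketch, assuming snp is scico.numpy: blocks of different
# shapes (here the horizontal and vertical differences of a 2D array) are
# combined into a single BlockArray.
>>> import scico.numpy as snp
>>> x = snp.ones((4, 5))
>>> x_h = x[:, 1:] - x[:, :-1]         # shape (4, 4)
>>> x_v = x[1:, :] - x[:-1, :]         # shape (3, 5)
>>> x_B = snp.blockarray([x_h, x_v])
>>> x_B[0].shape, x_B[1].shape
((4, 4), (3, 5))
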
@@ -151,7 +149,7 @@ Constructing a BlockArray
-------------------------

The recommended way to construct a :class:`.BlockArray` is by using the
`snp.blockarray` function.
:func:`snp.blockarray` function.

::

@@ -167,8 +165,8 @@ The recommended way to construct a :class:`.BlockArray` is by using the
2

While :func:`.snp.blockarray` will accept either :class:`~numpy.ndarray`\ s or
:obj:`~jax.numpy.DeviceArray`\ s as input, :class:`~numpy.ndarray`\ s
will be converted to :obj:`~jax.Array`\ s.
:class:`~jax.Array`\ s as input, :class:`~numpy.ndarray`\ s will be converted to
:class:`~jax.Array`\ s.
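
As a rough, hypothetical illustration of this conversion (assuming ``snp`` is
``scico.numpy``):

::

>>> import jax
>>> import numpy as np
>>> import scico.numpy as snp
>>> x = snp.blockarray([np.zeros(3), np.ones((2, 2))])  # NumPy inputs
>>> isinstance(x[0], jax.Array)                         # blocks are jax.Arrays
True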


Operating on a BlockArray
27 changes: 13 additions & 14 deletions docs/source/notes.rst
@@ -223,17 +223,16 @@ to enforce a particular derivative convention at a point.
JAX Arrays
==========

JAX utilizes a new array type
:class:`~jaxlib.xla_extension.DeviceArray`, which is similar to NumPy
:class:`~numpy.ndarray`, but can be backed by CPU, GPU, or TPU memory
and are immutable.
JAX utilizes a new array type :class:`~jax.Array`, which is similar to
NumPy :class:`~numpy.ndarray`, but can be backed by CPU, GPU, or TPU
memory and is immutable.
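
For example, arrays created through ``jax.numpy`` are instances of this type
(a minimal sketch):

::

import jax
import jax.numpy as jnp

x = jnp.arange(4)                # placed on the default device (CPU, GPU, or TPU)
assert isinstance(x, jax.Array)  # jax.numpy arrays are jax.Array instances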


DeviceArrays and NumPy Arrays
-----------------------------
JAX and NumPy Arrays
--------------------

SCICO and JAX functions can be applied directly to NumPy arrays
without explicit conversion to DeviceArrays, but this is not
without explicit conversion to JAX arrays, but this is not
recommended, as it can result in repeated data transfers from the CPU
to GPU. Consider this toy example on a system with a GPU present:

@@ -247,28 +246,28 @@


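A rough sketch of the pattern being described (using the same ``np`` and
``snp`` names as the corrected example below):

::

x = np.random.randn(8)       # array on host (NumPy)
A = np.random.randn(8, 8)    # array on host (NumPy)
y = snp.dot(A, x)            # A and x are transferred to the GPU for this call
z = y + x                    # x is transferred to the GPU again here
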
The unnecessary transfer can be avoided by first converting ``A`` and ``x`` to
DeviceArrays:
JAX arrays:

::

x = np.random.randn(8) # Array on host
A = np.random.randn(8, 8) # Array on host
x = jax.device_put(x) # Transfer to GPU
x = np.random.randn(8) # array on host
A = np.random.randn(8, 8) # array on host
x = jax.device_put(x) # transfer to GPU
A = jax.device_put(A)
y = snp.dot(A, x) # no transfer needed
z = y + x # no transfer needed


We recommend that input data be converted to DeviceArray via
We recommend that input data be converted to JAX arrays via
``jax.device_put`` before calling any SCICO optimizers.

On a multi-GPU system, ``jax.device_put`` can place data on a specific
GPU. See the `JAX notes on data placement
<https://jax.readthedocs.io/en/latest/faq.html?highlight=data%20placement#controlling-data-and-computation-placement-on-devices>`_.
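
For example (a hypothetical sketch, assuming at least two GPUs are visible to
JAX):

::

import jax

gpus = jax.devices("gpu")        # list of available GPU devices
x = jax.device_put(x, gpus[1])   # place x on the second GPU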


DeviceArrays are Immutable
--------------------------
JAX Arrays are Immutable
------------------------

Unlike standard NumPy arrays, JAX arrays are immutable: once they have
been created, they cannot be changed. This prohibits in-place updating
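
The functional update syntax that takes the place of in-place assignment looks
like this (a minimal sketch):

::

import jax.numpy as jnp

x = jnp.zeros(4)
# x[0] = 1.0 would raise an error since x is immutable;
# instead, construct an updated copy:
y = x.at[0].set(1.0)
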
18 changes: 9 additions & 9 deletions examples/scripts/denoise_tv_pgm.py
@@ -37,10 +37,9 @@
import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, operator, plot
from scico.numpy import BlockArray
from scico.numpy import Array, BlockArray
from scico.numpy.util import ensure_on_device
from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize
from scico.typing import JaxArray
from scico.util import device_info

"""
@@ -85,7 +84,7 @@
class DualTVLoss(loss.Loss):
def __init__(
self,
y: Union[JaxArray, BlockArray],
y: Union[Array, BlockArray],
A: Optional[Union[Callable, operator.Operator]] = None,
lmbda: float = 0.5,
):
@@ -94,7 +93,7 @@ def __init__(
super().__init__(y=y, A=A, scale=1.0)
self.lmbda = lmbda

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
def __call__(self, x: Union[Array, BlockArray]) -> float:

xint = self.y - self.lmbda * self.A(x)
return -1.0 * self.functional(xint - jnp.clip(xint, 0.0, 1.0)) + self.functional(xint)
@@ -112,10 +111,10 @@ class IsoProjector(functional.Functional):
has_eval = True
has_prox = True

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
def __call__(self, x: Union[Array, BlockArray]) -> float:
return 0.0

def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray:
def prox(self, v: Array, lam: float, **kwargs) -> Array:
norm_v_ptp = jnp.sqrt(jnp.sum(jnp.abs(v) ** 2, axis=0))

x_out = v / jnp.maximum(jnp.ones(v.shape), norm_v_ptp)
@@ -165,10 +164,10 @@ class AnisoProjector(functional.Functional):
has_eval = True
has_prox = True

def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
def __call__(self, x: Union[Array, BlockArray]) -> float:
return 0.0

def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray:
def prox(self, v: Array, lam: float, **kwargs) -> Array:

return v / jnp.maximum(jnp.ones(v.shape), jnp.abs(v))

@@ -194,6 +193,7 @@ def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray:
)

# Run the solver.
print()
x = solver.solve()
# Project to constraint set.
x_aniso = jnp.clip(y - f.lmbda * f.A(x), 0.0, 1.0)
@@ -203,7 +203,7 @@ def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray:
Compute the data fidelity.
"""
df = hist_iso.Objective[-1]
print(f"Data fidelity for isotropic TV was {df:.2e}")
print(f"\nData fidelity for isotropic TV was {df:.2e}")
hist = solver.itstat_object.history(transpose=True)
df = hist.Objective[-1]
print(f"Data fidelity for anisotropic TV was {df:.2e}")
26 changes: 17 additions & 9 deletions scico/__init__.py
@@ -14,21 +14,29 @@
import sys

# isort: off
from ._autograd import grad, jacrev, linear_adjoint, value_and_grad, cvjp

import jax, jaxlib

from jax import custom_jvp, custom_vjp, jacfwd, jvp, linearize, vjp, hessian

from . import numpy

# Suppress jax device warning. See https://github.com/google/jax/issues/6805
# This only works for jax>0.3.23; for earlier versions, the getLogger
# argument should be "absl".
logging.getLogger("jax._src.lib.xla_bridge").addFilter(
# argument should be "absl". Two filters are included here due to a change
# in jax between versions 0.4.2 and 0.4.8, both of which are supported by
# scico.
logging.getLogger("jax._src.lib.xla_bridge").addFilter( # jax 0.4.2
logging.Filter("No GPU/TPU found, falling back to CPU.")
)
logging.getLogger("jax._src.xla_bridge").addFilter( # jax 0.4.8
logging.Filter("No GPU/TPU found, falling back to CPU.")
)

# isort: on

import jax
from jax import custom_jvp, custom_vjp, hessian, jacfwd, jvp, linearize, vjp

import jaxlib

from . import numpy
from ._autograd import cvjp, grad, jacrev, linear_adjoint, value_and_grad

__all__ = [
"grad",
"value_and_grad",
8 changes: 3 additions & 5 deletions scico/data/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2022 by SCICO Developers
# Copyright (C) 2021-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -10,16 +10,14 @@
import os.path
from typing import Optional

from jax.interpreters.xla import DeviceArray

from imageio.v2 import imread

import scico.numpy as snp

__all__ = ["kodim23"]


def _imread(filename: str, path: Optional[str] = None, asfloat: bool = False) -> DeviceArray:
def _imread(filename: str, path: Optional[str] = None, asfloat: bool = False) -> snp.Array:
"""Read an image from disk.

Args:
@@ -40,7 +38,7 @@ def _imread(filename: str, path: Optional[str] = None, asfloat: bool = False) ->
return im


def kodim23(asfloat: bool = False) -> DeviceArray:
def kodim23(asfloat: bool = False) -> snp.Array:
"""Return the `kodim23` test image.

Args:
13 changes: 6 additions & 7 deletions scico/denoiser.py
@@ -35,10 +35,9 @@
import scico.numpy as snp
from scico.data import _flax_data_path
from scico.flax import DnCNNNet, FlaxMap, load_weights
from scico.typing import JaxArray


def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False, profile: Union[BM3DProfile, str] = "np"):
def bm3d(x: snp.Array, sigma: float, is_rgb: bool = False, profile: Union[BM3DProfile, str] = "np"):
r"""An interface to the BM3D denoiser :cite:`dabov-2008-image`.

BM3D denoising is performed using the
@@ -66,12 +65,12 @@ def bm3d(x: JaxArray, sigma: float, is_rgb: bool = False, profile: Union[BM3DPro

if is_rgb is True:

def bm3d_eval(x: JaxArray, sigma: float):
def bm3d_eval(x: snp.Array, sigma: float):
return tubm3d.bm3d_rgb(x, sigma, profile=profile)

else:

def bm3d_eval(x: JaxArray, sigma: float):
def bm3d_eval(x: snp.Array, sigma: float):
return tubm3d.bm3d(x, sigma, profile=profile)

if snp.util.is_complex_dtype(x.dtype):
@@ -112,7 +111,7 @@ def bm3d_eval(x: JaxArray, sigma: float):
return y


def bm4d(x: JaxArray, sigma: float, profile: Union[BM4DProfile, str] = "np"):
def bm4d(x: snp.Array, sigma: float, profile: Union[BM4DProfile, str] = "np"):
r"""An interface to the BM4D denoiser :cite:`maggioni-2012-nonlocal`.

BM4D denoising is performed using the
@@ -134,7 +133,7 @@ def bm4d(x: JaxArray, sigma: float, profile: Union[BM4DProfile, str] = "np"):
if not have_bm4d:
raise RuntimeError("Package bm4d is required for use of this function.")

def bm4d_eval(x: JaxArray, sigma: float):
def bm4d_eval(x: snp.Array, sigma: float):
return tubm4d.bm4d(x, sigma, profile=profile)

if snp.util.is_complex_dtype(x.dtype):
@@ -230,7 +229,7 @@ def __init__(self, variant: str = "6M"):
variables = load_weights(_flax_data_path("dncnn%s.npz" % variant))
super().__init__(model, variables)

def __call__(self, x: JaxArray, sigma: Optional[float] = None) -> JaxArray:
def __call__(self, x: snp.Array, sigma: Optional[float] = None) -> snp.Array:
r"""Apply DnCNN denoiser.

Args: