Skip to content

Commit

Permalink
Merge pull request #290 from carterbox/fuse-reduction
Browse files Browse the repository at this point in the history
NEW: Use CuPy fuse to merge some reduction kernels
  • Loading branch information
carterbox authored Oct 31, 2023
2 parents c76e56a + e2abb84 commit 579b62f
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 13 deletions.
15 changes: 11 additions & 4 deletions src/tike/operators/cupy/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,22 @@

import numpy.typing as npt
import numpy as np
import cupy as cp

from .operator import Operator
from .propagation import Propagation
from .convolution import Convolution
from . import objective


@cp.fuse()
def _intensity_from_farplane(farplane):
    """Return the squared magnitude of `farplane` summed over its middle axes.

    The element-wise product with the conjugate and the reduction are
    wrapped in `cp.fuse` together. The first axis and the last two axes of
    `farplane` are kept; every axis in between is summed away.
    """
    squared_magnitude = cp.real(farplane * cp.conj(farplane))
    # Axes 1 .. ndim - 3 are reduced; axis 0 and the trailing two remain.
    summed_axes = tuple(range(1, farplane.ndim - 2))
    return cp.sum(squared_magnitude, axis=summed_axes)


class Ptycho(Operator):
"""A Ptychography operator.
Expand Down Expand Up @@ -159,10 +169,7 @@ def _compute_intensity(
scan=scan,
probe=probe,
)
return self.xp.sum(
(farplane * farplane.conj()).real,
axis=tuple(range(1, farplane.ndim - 2)),
), farplane
return _intensity_from_farplane(farplane), farplane

def cost(
self,
Expand Down
28 changes: 19 additions & 9 deletions src/tike/ptycho/solvers/_preconditioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ def _rolling_average(old, new):
return 0.5 * (new + old)


@cp.fuse()
def _probe_amp_sum(probe):
    """Return `probe` times its own conjugate, summed along axis -3.

    The product and the reduction are wrapped in `cp.fuse` together. The
    real part is not extracted here, so the result keeps whatever dtype the
    product has.
    """
    conjugate_product = probe * cp.conj(probe)
    return cp.sum(conjugate_product, axis=-3)


def _psi_preconditioner(
psi: npt.NDArray[tike.precision.cfloating],
scan: npt.NDArray[tike.precision.floating],
Expand All @@ -31,10 +39,7 @@ def make_certain_args_constant(
scan = ind_args[0]
psi_update_denominator = mod_args[0]

probe_amp = cp.sum(
probe * probe.conj(),
axis=-3,
)[:, 0]
probe_amp = _probe_amp_sum(probe)[:, 0]
psi_update_denominator = operator.diffraction.patch.adj(
patches=probe_amp,
images=psi_update_denominator,
Expand All @@ -59,6 +64,15 @@ def make_certain_args_constant(
)[0]


@cp.fuse()
def _patch_amp_sum(patches):
    """Return `patches` times its own conjugate, summed along axis 0.

    The product and the reduction are wrapped in `cp.fuse` together. The
    summed axis is dropped from the result (`keepdims=False`).
    """
    conjugate_product = patches * cp.conj(patches)
    return cp.sum(conjugate_product, axis=0, keepdims=False)


def _probe_preconditioner(
psi: npt.NDArray[tike.precision.cfloating],
scan: npt.NDArray[tike.precision.floating],
Expand All @@ -81,11 +95,7 @@ def make_certain_args_constant(
positions=scan,
patch_width=probe.shape[-1],
)
probe_update_denominator += cp.sum(
patches * patches.conj(),
axis=0,
keepdims=False,
)
probe_update_denominator += _patch_amp_sum(patches)
assert probe_update_denominator.ndim == 2
return (probe_update_denominator,)

Expand Down
3 changes: 3 additions & 0 deletions tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def print_sample_error(indices):
p0 = print_sample_error(
tike.cluster.wobbly_center(self.population, self.num_cluster))
print('random sample')
np.random.seed(0)
p1 = print_sample_error(batch_indicies(self.num_pop, self.num_cluster))

# We should be more confident that wobbly samples are the same
Expand Down Expand Up @@ -148,6 +149,7 @@ def print_sample_error(indices):
p0 = print_sample_error(
tike.cluster.wobbly_center(self.population, self.num_cluster))
print('random sample')
np.random.seed(0)
p1 = print_sample_error(batch_indicies(self.num_pop, self.num_cluster))

# We should be more confident that wobbly samples are the same
Expand Down Expand Up @@ -194,6 +196,7 @@ def print_sample_error(indices):
p0 = print_sample_error(
tike.cluster.compact(self.population, self.num_cluster))
print('random sample')
np.random.seed(0)
p1 = print_sample_error(batch_indicies(self.num_pop, self.num_cluster))

# Every compact cluster should have smaller deviation than a random
Expand Down

0 comments on commit 579b62f

Please sign in to comment.