Skip to content

Commit

Permalink
Merge pull request #287 from carterbox/new-split
Browse files Browse the repository at this point in the history
REF: Load batches into contiguous memory blocks
  • Loading branch information
carterbox authored Nov 3, 2023
2 parents 579b62f + 080137f commit dfa8340
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 77 deletions.
216 changes: 187 additions & 29 deletions src/tike/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,29 @@
logger = logging.getLogger(__name__)


def _split_gpu(m, x, dtype):
def _split_gpu(
m: npt.NDArray,
x: npt.ArrayLike,
dtype: npt.DTypeLike,
) -> npt.ArrayLike:
return cp.asarray(x[m], dtype=dtype)


def _split_host(m, x, dtype):
def _split_host(
m: npt.NDArray,
x: npt.ArrayLike,
dtype: npt.DTypeLike,
) -> npt.ArrayLike:
return np.asarray(x[m], dtype=dtype)

def _split_pinned(m, x, dtype):
unpinned = x[m]
pinned = cupyx.empty_like_pinned(x[m], dtype=dtype)
pinned[:] = unpinned

def _split_pinned(
m: npt.NDArray,
x: npt.ArrayLike,
dtype: npt.DTypeLike,
) -> npt.ArrayLike:
pinned = cupyx.empty_pinned(shape=(len(m), *x.shape[1:]), dtype=dtype)
pinned[...] = x[m]
return pinned


Expand Down Expand Up @@ -94,7 +106,7 @@ def by_scan_stripes(
n: int,
fly: int = 1,
axis: int = 0,
) -> typing.List[npt.NDArray[bool]]:
) -> typing.List[npt.NDArray[np.bool_]]:
"""Return `n` boolean masks that split the field of view into stripes.
Mask divide the data into spatially contiguous regions along the position
Expand Down Expand Up @@ -158,6 +170,142 @@ def by_scan_stripes(
]


def by_scan_stripes_contiguous(
*args,
pool: tike.communicators.ThreadPool,
shape: typing.Tuple[int],
dtype: typing.List[npt.DTypeLike],
destination: typing.List[str],
scan: npt.NDArray[np.float32],
fly: int = 1,
batch_method,
num_batch: int,
) -> typing.Tuple[typing.List[npt.NDArray],
typing.List[typing.List[npt.NDArray]]]:
"""Split data by into stripes and create contiguously ordered batches.
Divide the field of view into one stripe per devices; within each stripe,
create batches according to the batch_method loading the batches into
contiguous blocks in device memory.
Parameters
----------
shape : tuple of int
The number of grid divisions along each dimension.
dtype : List[str]
The datatypes of the args after splitting.
scan : (nscan, 2) float32
The 2D coordinates of the scan positions.
args : (nscan, ...) float32 or None
The arrays to be split by scan position.
fly : int
The number of scan positions per frame.
batch_method :
The method for determining the batches after dividing amongst GPUs
Returns
-------
order : List[array[int]]
The locations of the inputs in the original arrays.
batches : List[List[array[int]]]
The locations of the elements of each batch
scan : List[array[float32]]
The divided 2D coordinates of the scan positions.
args : List[array[float32]] or None
Each input divided into regions or None if arg was None.
"""
if len(shape) != 2:
raise ValueError('The grid shape must have two dimensions.')

map_to_gpu = stripes_equal_count(
population=scan,
num_cluster=shape[0] * shape[1],
dim=0,
)
split_scan = pool.map(
_split_host,
map_to_gpu,
x=scan,
dtype=scan.dtype,
)
batches_noncontiguous: typing.List[typing.List[npt.NDArray]] = pool.map(
getattr(tike.cluster, batch_method),
split_scan,
num_cluster=num_batch,
)
map_to_gpu_contiguous: typing.List[npt.NDArray] = []
batches_contiguous: typing.List[typing.List[npt.NDArray]] = []
for gpu_map, batch_map in zip(map_to_gpu, batches_noncontiguous):
batch_indices = gpu_map[np.concatenate(batch_map)]
map_to_gpu_contiguous.append(batch_indices)
batch_sizes = [len(batch) for batch in batch_map]
batch_breaks = np.cumsum(batch_sizes)[:-1]
batches_contiguous.append(
np.array_split(
np.arange(len(batch_indices)),
batch_breaks,
))

split_args = []
for arg, t, dest in zip([scan, *args], dtype, destination):
if arg is None:
split_args.append(None)
else:
split_args.append(
pool.map(
_split_gpu if dest == 'gpu' else _split_pinned,
map_to_gpu_contiguous,
x=arg,
dtype=t,
))

if __debug__:
for device in batches_contiguous:
assert len(device) == num_batch, (
f"There should be {num_batch} batches, found {len(device)}"
)

return (map_to_gpu_contiguous, batches_contiguous, *split_args)


def stripes_equal_count(
population: npt.ArrayLike,
num_cluster: int,
dim: int = 0,
) -> typing.List[npt.NDArray]:
"""Return indices dividing the population into stripes of equal count.
The returned clusters are divided along the provided dimension into
clusters of approximate equal numbers of elements.
Parameters
----------
population : (M, N) array_like
The M samples of an N dimensional population that needs to be
clustered.
num_cluster : int (0..M]
The number of clusters in which to divide M samples.
dim : int
The dimension (of N) along which the population is divided.
Returns
-------
indicies : (num_cluster,) list of array of integer
The indicies of population that belong to each cluster.
"""
logger.info("Clustering method is stripes.")
xp = cp.get_array_module(population)
if (num_cluster == 1) or (num_cluster >= len(population)):
return np.array_split(np.arange(population.shape[0]), num_cluster)
# Sort the population along the dimension, then split into ranges of approx
# equal size
return np.array_split(
cp.asnumpy(xp.argsort(population[:, dim])),
num_cluster,
)


def wobbly_center(population, num_cluster):
"""Return the indices that divide population into heterogenous clusters.
Expand Down Expand Up @@ -195,12 +343,12 @@ def wobbly_center(population, num_cluster):
"""
logger.info("Clustering method is wobbly center.")
xp = cp.get_array_module(population)
if num_cluster == 1 or num_cluster == population.shape[0]:
return xp.split(xp.arange(population.shape[0]), num_cluster)
if not 0 < num_cluster <= min(0xFFFF, population.shape[0]):
if not 0 < num_cluster < 0xFFFF:
raise ValueError(
f"The number of clusters must be 0 < {num_cluster} < min(65536, M)."
f"The number of clusters must be 0 < {num_cluster} < 65536."
)
if (num_cluster == 1) or (num_cluster >= len(population)):
return np.array_split(np.arange(population.shape[0]), num_cluster)
# Start with the num_cluster observations closest to the global centroid
starting_centroids = xp.argpartition(
xp.linalg.norm(population - xp.mean(population, axis=0, keepdims=True),
Expand Down Expand Up @@ -234,10 +382,10 @@ def wobbly_center(population, num_cluster):


def wobbly_center_random_bootstrap(
population,
num_cluster: int,
boot_fraction: float = 0.95,
):
population,
num_cluster: int,
boot_fraction: float = 0.95,
) -> typing.List[npt.NDArray]:
"""Return the indices that divide population into heterogenous clusters.
Uses a hybrid approach to generate heterogenous clusters. First, a fraction
Expand Down Expand Up @@ -277,12 +425,12 @@ def wobbly_center_random_bootstrap(
"""
logger.info("Clustering method is wobbly center with random bootstrap.")
xp = cp.get_array_module(population)
if num_cluster == 1 or num_cluster == population.shape[0]:
return xp.split(xp.arange(population.shape[0]), num_cluster)
if not 0 < num_cluster <= min(0xFFFF, population.shape[0]):
if not 0 < num_cluster < 0xFFFF:
raise ValueError(
f"The number of clusters must be 0 < {num_cluster} < min(65536, M)."
f"The number of clusters must be 0 < {num_cluster} < 65536."
)
if (num_cluster == 1) or (num_cluster >= len(population)):
return np.array_split(np.arange(population.shape[0]), num_cluster)
# Partially initialize the clusters randomly; each cluster starts with an
# equal number of members
num_bootstrap = int(len(population) * boot_fraction)
Expand Down Expand Up @@ -349,12 +497,12 @@ def compact(population, num_cluster, max_iter=500):
logger.info("Clustering method is compact.")
# Indexing and serial operations is very slow on GPU, so always use host
population = cp.asnumpy(population)
if num_cluster == 1 or num_cluster == population.shape[0]:
return np.split(np.arange(population.shape[0]), num_cluster)
if not 0 < num_cluster <= min(0xFFFF, population.shape[0]):
if not 0 < num_cluster < 0xFFFF:
raise ValueError(
f"The number of clusters must be 0 < {num_cluster} < min(65536, M)."
f"The number of clusters must be 0 < {num_cluster} < 65536."
)
if (num_cluster == 1) or (num_cluster >= len(population)):
return np.array_split(np.arange(population.shape[0]), num_cluster)
# Define a constant array used for indexing.
_all = np.arange(len(population))
# Specify the number of points allowed in each cluster
Expand All @@ -363,8 +511,7 @@ def compact(population, num_cluster, max_iter=500):
max_size[:len(population) % num_cluster] += 1
assert np.sum(max_size) == len(population), (
f"Sum of cluster maximums {np.sum(max_size)} should "
f"equal the population size {len(population)}!"
)
f"equal the population size {len(population)}!")

# Use kmeans++ to choose initial cluster centers
starting_centroids = np.zeros(num_cluster, dtype='int')
Expand Down Expand Up @@ -420,8 +567,7 @@ def compact(population, num_cluster, max_iter=500):
_unassigned.remove(p)
size[nearest[p]] += 1
assert size[nearest[p]] <= max_size[nearest[p]], (
f"{size[nearest[p]]} !<= {max_size[nearest[p]]}"
)
f"{size[nearest[p]]} !<= {max_size[nearest[p]]}")
if size[nearest[p]] >= max_size[nearest[p]]:
unfilled_clusters.remove(nearest[p])
# re-start with one less available cluster
Expand Down Expand Up @@ -498,8 +644,8 @@ def _k_means_objective(population, labels, num_cluster):
for c in range(num_cluster):
weight = xp.sum(labels == c)
if weight > 1:
cost += weight * abs(xp.linalg.det(
xp.cov(
cost += weight * abs(
xp.linalg.det(xp.cov(
population[labels == c],
rowvar=False,
)))
Expand Down Expand Up @@ -532,3 +678,15 @@ def cluster_compact(*args, **kwargs):
DeprecationWarning,
)
return compact(*args, **kwargs)


def _batch_ends(
num_batch: int,
size: int,
index: int,
) -> typing.Tuple[int, int]:
batch_size = size // num_batch
remainder = size % num_batch
lo = batch_size * index + min(remainder, index)
hi = lo + batch_size + (1 if index < remainder else 0)
return (lo, hi)
28 changes: 11 additions & 17 deletions src/tike/ptycho/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,26 +336,28 @@ def __enter__(self):
warnings.warn(
"Diffraction patterns contain invalid data. "
"All data should be non-negative and finite.", UserWarning)
odd_pool = self.comm.pool.num_workers % 2

(
self.comm.order,
self.batches,
self.parameters.scan,
self.data,
self.parameters.eigen_weights,
) = tike.cluster.by_scan_grid(
) = tike.cluster.by_scan_stripes_contiguous(
self.data,
self.parameters.eigen_weights,
scan=self.parameters.scan,
pool=self.comm.pool,
shape=(
self.comm.pool.num_workers
if odd_pool else self.comm.pool.num_workers // 2,
1 if odd_pool else 2,
shape=(self.comm.pool.num_workers, 1),
dtype=(
tike.precision.floating,
tike.precision.floating
if self.data.itemsize > 2 else self.data.dtype,
tike.precision.floating,
),
dtype=(tike.precision.floating, tike.precision.floating
if self.data.itemsize > 2 else self.data.dtype,
tike.precision.floating),
destination=('gpu', 'pinned', 'gpu'),
batch_method=self.parameters.algorithm_options.batch_method,
num_batch=self.parameters.algorithm_options.num_batch,
)

self.parameters.psi = self.comm.pool.bcast(
Expand Down Expand Up @@ -388,14 +390,6 @@ def __enter__(self):
for x in self.comm.order),
)

# Unique batch for each device
self.batches = self.comm.pool.map(
getattr(tike.cluster,
self.parameters.algorithm_options.batch_method),
self.parameters.scan,
num_cluster=self.parameters.algorithm_options.num_batch,
)

self.parameters.probe = _rescale_probe(
self.operator,
self.comm,
Expand Down
Loading

0 comments on commit dfa8340

Please sign in to comment.