Merge branch 'main' into fix/mpi4py-4-support
mtar authored Aug 20, 2024
2 parents 43ee9f9 + 15c4478 commit 941d32f
Showing 7 changed files with 182 additions and 24 deletions.
12 changes: 5 additions & 7 deletions .github/ISSUE_TEMPLATE/bug_report.yml
@@ -34,8 +34,8 @@ body:
      description: What version of Heat are you running?
      options:
        - main (development branch)
        - 1.4.x
        - 1.3.x
        - 1.2.x
    validations:
      required: true
  - type: dropdown
@@ -44,23 +44,21 @@
      label: Python version
      description: What Python version?
      options:
        - 3.12
        - 3.11
        - "3.10"
        - 3.9
        - 3.8
  - type: dropdown
    id: pytorch-version
    attributes:
      label: PyTorch version
      description: What PyTorch version?
      options:
        - 2.4
        - 2.3
        - 2.2
        - 2.1
        - 2.0
        - 1.13
        - 1.12
        - 1.11
        - "1.10"
        - '2.0'
  - type: textarea
    id: mpi-version
    attributes:
1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
@@ -23,6 +23,7 @@ jobs:
          - 'torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2'
          - 'torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2'
          - 'torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1'
+          - 'torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0'
        exclude:
          - py-version: '3.12'
            pytorch-version: 'torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2'
1 change: 1 addition & 0 deletions README.md
@@ -21,6 +21,7 @@ Heat is a distributed tensor framework for high performance data analytics.
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.2531472.svg)](https://doi.org/10.5281/zenodo.2531472)
[![Benchmarks](https://img.shields.io/badge/Grafana-Benchmarks-2ea44f)](https://57bc8d92-72f2-4869-accd-435ec06365cb.ka.bw-cloud-instance.org:3000/d/adjpqduq9r7k0a/heat-cb?orgId=1)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+[![JuRSE Code Pick of the Month](https://img.shields.io/badge/JuRSE_Code_Pick-August_2024-blue)](https://www.fz-juelich.de/en/rse/jurse-community/jurse-code-of-the-month/august-2024)

# Table of Contents
- [What is Heat for?](#what-is-heat-for)
38 changes: 23 additions & 15 deletions heat/core/dndarray.py
@@ -384,14 +384,18 @@ def __prephalo(self, start, end) -> torch.Tensor:

        return self.__array[ix].clone().contiguous()

-    def get_halo(self, halo_size: int) -> torch.Tensor:
+    def get_halo(self, halo_size: int, prev: bool = True, next: bool = True) -> torch.Tensor:
        """
        Fetch halos of size ``halo_size`` from neighboring ranks and save them in ``self.halo_next/self.halo_prev``.

        Parameters
        ----------
        halo_size : int
            Size of the halo.
+        prev : bool, optional
+            If True, fetch the halo from the previous rank. Default: True.
+        next : bool, optional
+            If True, fetch the halo from the next rank. Default: True.
        """
        if not isinstance(halo_size, int):
            raise TypeError(

@@ -433,25 +437,29 @@ def get_halo(self, halo_size: int) -> torch.Tensor:
        req_list = []

        # exchange data with next populated process
-        if rank != last_rank:
-            self.comm.Isend(a_next, next_rank)
-            res_prev = torch.zeros(
-                a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
-            )
-            req_list.append(self.comm.Irecv(res_prev, source=next_rank))
+        if prev:
+            if rank != last_rank:
+                self.comm.Isend(a_next, next_rank)
+            if rank != first_rank:
+                res_prev = torch.zeros(
+                    a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
+                )
+                req_list.append(self.comm.Irecv(res_prev, source=prev_rank))

-        if rank != first_rank:
-            self.comm.Isend(a_prev, prev_rank)
-            res_next = torch.zeros(
-                a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
-            )
-            req_list.append(self.comm.Irecv(res_next, source=prev_rank))
+        if next:
+            if rank != first_rank:
+                req_list.append(self.comm.Isend(a_prev, prev_rank))
+            if rank != last_rank:
+                res_next = torch.zeros(
+                    a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
+                )
+                req_list.append(self.comm.Irecv(res_next, source=next_rank))

        for req in req_list:
            req.Wait()

-        self.__halo_next = res_prev
-        self.__halo_prev = res_next
+        self.__halo_next = res_next
+        self.__halo_prev = res_prev
        self.__ishalo = True

    def __cat_halo(self) -> torch.Tensor:
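For context, a minimal usage sketch of the extended `get_halo` (an illustration under assumptions, not part of the commit: it presumes an MPI launch with at least two ranks, e.g. `mpirun -n 2 python demo.py`, where `demo.py` is a hypothetical script):

```python
import heat as ht

# DNDarray distributed along axis 0; each rank owns a contiguous chunk
x = ht.arange(16, split=0)

# fetch only the halo coming from the next rank; prev=False skips the other direction
x.get_halo(2, prev=False)

# local torch tensor with the received halo attached
print(x.comm.rank, x.array_with_halos)
```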
90 changes: 90 additions & 0 deletions heat/core/manipulations.py
@@ -61,6 +61,7 @@
    "unique",
    "vsplit",
    "vstack",
    "unfold",
]


@@ -4213,3 +4214,92 @@ def mpi_topk(a, b, mpi_type):


MPI_TOPK = MPI.Op.Create(mpi_topk, commute=True)


def unfold(a: DNDarray, axis: int, size: int, step: int = 1):
    """
    Returns a DNDarray which contains all slices of size `size` in the axis `axis`.
    Behaves like [torch.Tensor.unfold](https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html) for DNDarrays.

    Parameters
    ----------
    a : DNDarray
        array to unfold
    axis : int
        axis in which unfolding happens
    size : int
        the size of each slice that is unfolded, must be greater than 1
    step : int
        the step between each slice, must be at least 1

    Example:
    ```
    >>> x = ht.arange(1., 8)
    >>> x
    DNDarray([1., 2., 3., 4., 5., 6., 7.], dtype=ht.float32, device=cpu:0, split=None)
    >>> ht.unfold(x, 0, 2, 1)
    DNDarray([[1., 2.],
              [2., 3.],
              [3., 4.],
              [4., 5.],
              [5., 6.],
              [6., 7.]], dtype=ht.float32, device=cpu:0, split=None)
    >>> ht.unfold(x, 0, 2, 2)
    DNDarray([[1., 2.],
              [3., 4.],
              [5., 6.]], dtype=ht.float32, device=cpu:0, split=None)
    ```

    Note
    ----
    If the split axis of the array is the unfold axis, every process must hold at least
    `size - 1` elements of the array along that axis.
    """
    if step < 1:
        raise ValueError("step must be >= 1.")
    if size <= 1:
        raise ValueError("size must be > 1.")
    axis = stride_tricks.sanitize_axis(a.shape, axis)
    if size > a.shape[axis]:
        raise ValueError(
            f"maximum size for DNDarray at axis {axis} is {a.shape[axis]} but size is {size}."
        )

    comm = a.comm
    dev = a.device
    tdev = dev.torch_device

    if a.split is None or comm.size == 1 or a.split != axis:  # early out
        ret = factories.array(
            a.larray.unfold(axis, size, step), is_split=a.split, device=dev, comm=comm
        )

        return ret
    else:  # comm.size > 1 and split axis == unfold axis
        # index range [0 : sizedim - 1 - (size - 1)] = [0 : sizedim - size]
        # --> size of axis: ceil((sizedim - size + 1) / step) = floor((sizedim - size) / step) + 1
        # ret_shape = (*a_shape[:axis], int((a_shape[axis] - size) / step) + 1, *a_shape[axis + 1:], size)

        if (size - 1 > a.lshape_map[:, axis]).any():
            raise RuntimeError("Chunk-size needs to be at least size - 1.")
        a.get_halo(size - 1, prev=False)

        counts, displs = a.counts_displs()
        displs = torch.tensor(displs, device=tdev)

        # min local index in unfold axis:
        # the first global window start at or after this rank's offset displs[rank]
        # is ceil(displs[rank] / step) * step; subtracting the offset gives the local index
        min_index = ((displs[comm.rank] - 1) // step + 1) * step - displs[comm.rank]
        if min_index >= a.lshape[axis] or (
            comm.rank == comm.size - 1 and min_index + size > a.lshape[axis]
        ):
            loc_unfold_shape = list(a.lshape)
            loc_unfold_shape[axis] = 0
            ret_larray = torch.zeros((*loc_unfold_shape, size), device=tdev)
        else:  # unfold has local data
            ret_larray = a.array_with_halos[
                axis * (slice(None, None, None),) + (slice(min_index, None, None), Ellipsis)
            ].unfold(axis, size, step)

        ret = factories.array(ret_larray, is_split=axis, device=dev, comm=comm)

        return ret
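To illustrate the distributed branch, a hedged sketch (assumes an MPI launch such as `mpirun -n 2 python demo.py`; `demo.py` is a hypothetical script):

```python
import heat as ht

# split along the unfold axis; with size=3, each rank must hold at least size - 1 = 2 elements
x = ht.arange(1., 13., split=0)

# windows [1., 2., 3.], [3., 4., 5.], ..., [9., 10., 11.]; the result stays split along axis 0
w = ht.unfold(x, 0, 3, 2)
print(w.shape)  # (5, 3)
```

Each rank computes its windows from its local chunk plus the `size - 1` halo fetched from the next rank, which is why the chunk-size requirement in the docstring exists.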
60 changes: 60 additions & 0 deletions heat/core/tests/test_manipulations.py
@@ -3752,3 +3752,63 @@ def test_vstack(self):
        b = ht.ones((12,), split=0)
        res = ht.vstack((a, b))
        self.assertEqual(res.shape, (2, 12))

    def test_unfold(self):
        dtypes = (ht.int, ht.float)

        for dtype in dtypes:  # test with different datatypes
            # exceptions
            n = 1000
            x = ht.arange(n, dtype=dtype)
            with self.assertRaises(ValueError):  # size too small
                ht.unfold(x, 0, 1, 1)
            with self.assertRaises(ValueError):  # step too small
                ht.unfold(x, 0, 2, 0)
            x.resplit_(0)
            min_chunk_size = x.lshape_map[:, 0].min().item()
            if min_chunk_size + 2 > n:  # size too large
                with self.assertRaises(ValueError):
                    ht.unfold(x, 0, min_chunk_size + 2)
            else:  # size too large for chunk_size
                with self.assertRaises(RuntimeError):
                    ht.unfold(x, 0, min_chunk_size + 2)
            with self.assertRaises(ValueError):  # size too large
                ht.unfold(x, 0, n + 1, 1)
            ht.unfold(
                x, 0, min_chunk_size, min_chunk_size + 1
            )  # no fully local unfolds on some nodes

            # 2D sliding views
            n = 100

            x = torch.arange(n * n).reshape((n, n))
            y = ht.array(x, dtype)
            y.resplit_(0)

            u = x.unfold(0, 3, 3)
            u = u.unfold(1, 3, 3)
            u = ht.array(u)
            v = ht.unfold(y, 0, 3, 3)
            v = ht.unfold(v, 1, 3, 3)

            self.assertTrue(ht.equal(u, v))

            # more dimensions, different split axes
            n = 53
            k = 3  # number of dimensions
            shape = k * (n,)
            size = n**k

            x = torch.arange(size).reshape(shape)
            _y = x.clone().detach()
            y = ht.array(_y, dtype)

            for split in (None, *range(k)):
                y.resplit_(split)
                for size in range(2, 9):
                    for step in range(1, 21):
                        for dimension in range(k):
                            u = ht.array(x.unfold(dimension, size, step))
                            v = ht.unfold(y, dimension, size, step)

                            self.assertTrue(ht.equal(u, v))
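As a usage note (hedged, not part of the diff): these tests only exercise the distributed code paths when launched under MPI, e.g. `mpirun -n 3 python -m pytest heat/core/tests/test_manipulations.py -k test_unfold`.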
4 changes: 2 additions & 2 deletions setup.py
@@ -35,10 +35,10 @@
    install_requires=[
        "mpi4py>=3.0.0",
        "numpy>=1.22.0, <2",
-        "torch>=2.0.0, <2.3.2",
+        "torch>=2.0.0, <2.4.1",
        "scipy>=1.10.0",
        "pillow>=6.0.0",
-        "torchvision>=0.15.2",
+        "torchvision>=0.15.2, <0.19.1",
    ],
    extras_require={
        "docutils": ["docutils>=0.16"],
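In practice (a hedged reading of these pins, not stated in the diff): a fresh `pip install` of this branch resolves `torch` to at most 2.4.0 and `torchvision` to at most 0.19.0, matching the new CI matrix entry above.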
