Skip to content

Commit

Permalink
torch20 transfer (#86)
Browse files Browse the repository at this point in the history
* 🎨 formatting

* 🔥 remove gain from dask storage

* 🎨 formatting

* 🎨 formatting

* ⚡️ change CLI defaults

* ✨ new loss enum

* ⚑️ relative import

* ✅ update tests

* πŸ”’οΈ upgrade setuptools
  • Loading branch information
jgrss authored Aug 14, 2024
1 parent e6dc57d commit d8bcfb2
Show file tree
Hide file tree
Showing 11 changed files with 248 additions and 172 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ tensorboard>=2.2.0
PyYAML>=5.1
geowombat@git+https://github.com/jgrss/geowombat.git
tsaug@git+https://github.com/jgrss/tsaug.git
setuptools==59.5.0
setuptools>=70
numpydoc
sphinx
sphinx-automodapi
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[build-system]
requires = [
'setuptools>=65.5.1',
'setuptools>=70',
'wheel',
'numpy<2,>=1.22',
]
Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ package_dir=
packages=find:
include_package_data = True
setup_requires =
setuptools>=65.5.1
setuptools>=70
wheel
numpy<2,>=1.22
python_requires =
Expand Down Expand Up @@ -63,7 +63,6 @@ install_requires =
geowombat@git+https://github.com/jgrss/geowombat.git
tsaug@git+https://github.com/jgrss/tsaug.git
pygrts@git+https://github.com/jgrss/[email protected]
setuptools>=65.5.1

[options.extras_require]
docs = numpydoc
Expand Down
144 changes: 100 additions & 44 deletions src/cultionet/data/create.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import typing as T
from pathlib import Path

Expand All @@ -11,6 +12,7 @@
import torch
import xarray as xr
from affine import Affine
from dask.diagnostics import ProgressBar
from dask.distributed import Client, LocalCluster, progress
from rasterio.windows import Window, from_bounds
from scipy.ndimage import label as nd_label
Expand Down Expand Up @@ -71,19 +73,22 @@ def reshape_and_mask_array(
num_bands: int,
gain: float,
offset: int,
apply_gain: bool = True,
) -> xr.DataArray:
"""Reshapes an array and masks no-data values."""

src_ts_stack = xr.DataArray(
# Date are stored [(band x time) x height x width]
dtype = 'float32' if apply_gain else 'int16'

time_series = xr.DataArray(
# Data are stored [(band x time) x height x width]
(
data.data.reshape(
num_bands,
num_time,
data.gw.nrows,
data.gw.ncols,
).transpose(1, 0, 2, 3)
).astype('float32'),
).astype(dtype),
dims=('time', 'band', 'y', 'x'),
coords={
'time': range(num_time),
Expand All @@ -94,12 +99,18 @@ def reshape_and_mask_array(
attrs=data.attrs.copy(),
)

with xr.set_options(keep_attrs=True):
time_series = (src_ts_stack.gw.mask_nodata() * gain + offset).fillna(0)
if apply_gain:

with xr.set_options(keep_attrs=True):
# Mask and scale the data
time_series = (
time_series.gw.mask_nodata() * gain + offset
).fillna(0)

return time_series


@threadpool_limits.wrap(limits=1, user_api="blas")
def create_predict_dataset(
image_list: T.List[T.List[T.Union[str, Path]]],
region: str,
Expand All @@ -113,26 +124,36 @@ def create_predict_dataset(
padding: int = 101,
num_workers: int = 1,
compress_method: T.Union[int, str] = 'zlib',
use_cluster: bool = True,
):
"""Creates a prediction dataset for an image."""

# Read windows larger than the re-chunk window size
read_chunksize = 1024
while True:
if read_chunksize < window_size:
read_chunksize *= 2
else:
break

with gw.config.update(ref_res=ref_res):
with gw.open(
image_list,
stack_dim="band",
band_names=list(range(1, len(image_list) + 1)),
resampling=resampling,
chunks=512,
chunks=read_chunksize,
) as src_ts:

# Get the time and band count
num_time, num_bands = get_image_list_dims(image_list, src_ts)

time_series: xr.DataArray = reshape_and_mask_array(
time_series = reshape_and_mask_array(
data=src_ts,
num_time=num_time,
num_bands=num_bands,
gain=gain,
offset=offset,
apply_gain=False,
)

# Chunk the array into the windows
Expand Down Expand Up @@ -172,42 +193,77 @@ def create_predict_dataset(
trim=False,
)

with dask.config.set(
{
"distributed.worker.memory.terminate": False,
"distributed.comm.retry.count": 10,
"distributed.comm.timeouts.connect": 5,
"distributed.scheduler.allowed-failures": 20,
}
):
with LocalCluster(
processes=True,
n_workers=num_workers,
threads_per_worker=1,
memory_target_fraction=0.97,
memory_limit="4GB", # per worker limit
) as cluster:
with Client(cluster) as client:
with BatchStore(
data=time_series,
write_path=process_path,
res=ref_res,
resampling=resampling,
region=region,
start_date=pd.to_datetime(
Path(image_list[0]).stem, format=date_format
).strftime("%Y%m%d"),
end_date=pd.to_datetime(
Path(image_list[-1]).stem, format=date_format
).strftime("%Y%m%d"),
window_size=window_size,
padding=padding,
compress_method=compress_method,
gain=gain,
) as batch_store:
save_tasks = batch_store.save(time_series_array)
results = client.persist(save_tasks)
progress(results)
if use_cluster:
with dask.config.set(
{
"distributed.worker.memory.terminate": False,
"distributed.comm.retry.count": 10,
"distributed.comm.timeouts.connect": 5,
"distributed.scheduler.allowed-failures": 20,
"distributed.worker.memory.pause": 0.95,
"distributed.worker.memory.target": 0.97,
"distributed.worker.memory.spill": False,
"distributed.scheduler.worker-saturation": 1.0,
}
):
with LocalCluster(
processes=True,
n_workers=num_workers,
threads_per_worker=1,
memory_limit="6GB", # per worker limit
silence_logs=logging.ERROR,
) as cluster:
with Client(cluster) as client:
with BatchStore(
data=time_series,
write_path=process_path,
res=ref_res,
resampling=resampling,
region=region,
start_date=pd.to_datetime(
Path(image_list[0]).stem,
format=date_format,
).strftime("%Y%m%d"),
end_date=pd.to_datetime(
Path(image_list[-1]).stem,
format=date_format,
).strftime("%Y%m%d"),
window_size=window_size,
padding=padding,
compress_method=compress_method,
) as batch_store:
save_tasks = batch_store.save(
time_series_array
)
results = client.gather(
client.persist(save_tasks)
)
progress(results)

else:

with dask.config.set(
scheduler='processes', num_workers=num_workers
):
with BatchStore(
data=time_series,
write_path=process_path,
res=ref_res,
resampling=resampling,
region=region,
start_date=pd.to_datetime(
Path(image_list[0]).stem, format=date_format
).strftime("%Y%m%d"),
end_date=pd.to_datetime(
Path(image_list[-1]).stem, format=date_format
).strftime("%Y%m%d"),
window_size=window_size,
padding=padding,
compress_method=compress_method,
) as batch_store:
save_tasks = batch_store.save(time_series_array)
with ProgressBar():
save_tasks.compute()


class ReferenceArrays:
Expand Down
4 changes: 1 addition & 3 deletions src/cultionet/data/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def __init__(
window_size: int,
padding: int,
compress_method: Union[int, str],
gain: float,
):
self.data = data
self.res = res
Expand All @@ -43,7 +42,6 @@ def __init__(
self.window_size = window_size
self.padding = padding
self.compress_method = compress_method
self.gain = gain

def __setitem__(self, key: tuple, item: np.ndarray) -> None:
time_range, index_range, y, x = key
Expand Down Expand Up @@ -87,7 +85,7 @@ def write_batch(self, x: np.ndarray, w: Window, w_pad: Window):
)

x = einops.rearrange(
torch.from_numpy(x / self.gain).to(dtype=torch.int32),
torch.from_numpy(x.astype('int32')).to(dtype=torch.int32),
't c h w -> 1 c t h w',
)

Expand Down
1 change: 1 addition & 0 deletions src/cultionet/enums/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class LossTypes(StrEnum):
CLASS_BALANCED_MSE = "ClassBalancedMSELoss"
TANIMOTO_COMPLEMENT = "TanimotoComplementLoss"
TANIMOTO = "TanimotoDistLoss"
TANIMOTO_COMBINED = "TanimotoCombined"
TOPOLOGY = "TopologyLoss"


Expand Down
1 change: 1 addition & 0 deletions src/cultionet/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .losses import (
BoundaryLoss,
ClassBalancedMSELoss,
CombinedLoss,
LossPreprocessing,
TanimotoComplementLoss,
TanimotoDistLoss,
Expand Down
Loading

0 comments on commit d8bcfb2

Please sign in to comment.