From 6ef9b9c24dfa46efc59e012668def6e90d020a1b Mon Sep 17 00:00:00 2001 From: Jordan Graesser Date: Thu, 15 Aug 2024 17:17:22 +1000 Subject: [PATCH] fix: torch20 ex (#87) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ⚡️ sum weights * ⚡️ move sum * ✅ update loss tests * 🎨 formatting * ⚡️ use common method * 🎨 formatting * ➖ make kornia optional * ➕ increase dependency versions --- setup.cfg | 8 +- src/cultionet/data/create.py | 213 +++++++++++++++------------------ src/cultionet/data/datasets.py | 4 +- src/cultionet/data/store.py | 15 ++- src/cultionet/losses/losses.py | 25 ++-- tests/test_loss.py | 10 +- 6 files changed, 132 insertions(+), 143 deletions(-) diff --git a/setup.cfg b/setup.cfg index 56c0466e..e02e65b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,6 +31,9 @@ python_requires = >=3.9,<3.11 install_requires = attrs>=21 + dask>=2024.8.0 + distributed>=2024.8.0 + xarray>=2024.7.0 frozendict>=2.2 frozenlist>=1.3 numpy<2,>=1.22 @@ -45,7 +48,7 @@ install_requires = decorator==4.4.2 rtree>=0.9.7 graphviz>=0.19 - tqdm>=4.62 + tqdm>=4.66 pyDeprecate==0.3.1 future>=0.17.1 tensorboard>=2.2 @@ -53,13 +56,12 @@ install_requires = lightning>=2.2 torchmetrics>=1.3 einops>=0.7 - ray<=2.1,>=2 + ray>=2.34 pyarrow>=11 typing-extensions lz4 rich-argparse pyogrio>=0.7 - kornia>=0.7 geowombat@git+https://github.com/jgrss/geowombat.git tsaug@git+https://github.com/jgrss/tsaug.git pygrts@git+https://github.com/jgrss/pygrts.git@v1.4.1 diff --git a/src/cultionet/data/create.py b/src/cultionet/data/create.py index 962de174..ef16db7d 100644 --- a/src/cultionet/data/create.py +++ b/src/cultionet/data/create.py @@ -1,20 +1,22 @@ -import logging import typing as T from pathlib import Path -import dask import dask.array as da import einops import geopandas as gpd import geowombat as gw import numpy as np import pandas as pd +import psutil +import ray import torch import xarray as xr from affine import Affine from dask.diagnostics import ProgressBar -from dask.distributed import Client, LocalCluster, progress +from psutil._common import bytes2human from rasterio.windows import Window, from_bounds +from ray.exceptions import RayTaskError +from ray.util.dask import enable_dask_on_ray, ray_dask_get from scipy.ndimage import label as nd_label from skimage.measure import regionprops from threadpoolctl import threadpool_limits @@ -110,7 +112,6 @@ def reshape_and_mask_array( return time_series -@threadpool_limits.wrap(limits=1, user_api="blas") def create_predict_dataset( image_list: T.List[T.List[T.Union[str, Path]]], region: str, @@ -121,130 +122,102 @@ def create_predict_dataset( ref_res: T.Union[float, T.Tuple[float, float]] = 10.0, resampling: str = "nearest", window_size: int = 100, - padding: int = 101, + padding: int = 20, num_workers: int = 1, compress_method: T.Union[int, str] = 'zlib', - use_cluster: bool = True, ): """Creates a prediction dataset for an image.""" # Read windows larger than the re-chunk window size - read_chunksize = 1024 + read_chunksize = 256 while True: - if read_chunksize < window_size: + if read_chunksize < window_size + padding: read_chunksize *= 2 else: break - with gw.config.update(ref_res=ref_res): - with gw.open( - image_list, - stack_dim="band", - band_names=list(range(1, len(image_list) + 1)), - resampling=resampling, - chunks=read_chunksize, - ) as src_ts: - # Get the time and band count - num_time, num_bands = get_image_list_dims(image_list, src_ts) - - time_series = reshape_and_mask_array( - data=src_ts, - num_time=num_time, - 
num_bands=num_bands, - gain=gain, - offset=offset, - apply_gain=False, - ) + total_cpus = psutil.cpu_count(logical=True) + threads_per_worker = total_cpus // num_workers - # Chunk the array into the windows - time_series_array = time_series.chunk( - {"time": -1, "band": -1, "y": window_size, "x": window_size} - ).data - - # Check if the array needs to be padded - # First, get the end chunk size of rows and columns - height_end_chunk = time_series_array.chunks[-2][-1] - width_end_chunk = time_series_array.chunks[-1][-1] - - height_padding = 0 - width_padding = 0 - if padding > height_end_chunk: - height_padding = padding - height_end_chunk - if padding > width_end_chunk: - width_padding = padding - width_end_chunk - - if (height_padding > 0) or (width_padding > 0): - # Pad the full array if the end chunk is smaller than the padding - time_series_array = da.pad( - time_series_array, - pad_width=( - (0, 0), - (0, 0), - (0, height_padding), - (0, width_padding), - ), - ).rechunk({0: -1, 1: -1, 2: window_size, 3: window_size}) - - # Add the padding to each chunk - time_series_array = time_series_array.map_overlap( - lambda x: x, - depth={0: 0, 1: 0, 2: padding, 3: padding}, - boundary=0, - trim=False, - ) + logger.info(f"Opening images with window chunk sizes of {read_chunksize}.") + logger.info( + f"Re-chunking image arrays to chunk sizes of {window_size} with padding of {padding}." + ) + logger.info( + f"Virtual memory available is {bytes2human(psutil.virtual_memory().available)}." + ) + logger.info( + f"Creating PyTorch dataset with {num_workers} processes and {threads_per_worker} threads." + ) + + with threadpool_limits(limits=threads_per_worker, user_api="blas"): + + with gw.config.update(ref_res=ref_res): + with gw.open( + image_list, + stack_dim="band", + band_names=list(range(1, len(image_list) + 1)), + resampling=resampling, + chunks=read_chunksize, + ) as src_ts: + # Get the time and band count + num_time, num_bands = get_image_list_dims(image_list, src_ts) - if use_cluster: - with dask.config.set( + time_series = reshape_and_mask_array( + data=src_ts, + num_time=num_time, + num_bands=num_bands, + gain=gain, + offset=offset, + apply_gain=False, + ) + + # Chunk the array into the windows + time_series_array = time_series.chunk( { - "distributed.worker.memory.terminate": False, - "distributed.comm.retry.count": 10, - "distributed.comm.timeouts.connect": 5, - "distributed.scheduler.allowed-failures": 20, - "distributed.worker.memory.pause": 0.95, - "distributed.worker.memory.target": 0.97, - "distributed.worker.memory.spill": False, - "distributed.scheduler.worker-saturation": 1.0, + "time": -1, + "band": -1, + "y": window_size, + "x": window_size, } - ): - with LocalCluster( - processes=True, - n_workers=num_workers, - threads_per_worker=1, - memory_limit="6GB", # per worker limit - silence_logs=logging.ERROR, - ) as cluster: - with Client(cluster) as client: - with BatchStore( - data=time_series, - write_path=process_path, - res=ref_res, - resampling=resampling, - region=region, - start_date=pd.to_datetime( - Path(image_list[0]).stem, - format=date_format, - ).strftime("%Y%m%d"), - end_date=pd.to_datetime( - Path(image_list[-1]).stem, - format=date_format, - ).strftime("%Y%m%d"), - window_size=window_size, - padding=padding, - compress_method=compress_method, - ) as batch_store: - save_tasks = batch_store.save( - time_series_array - ) - results = client.gather( - client.persist(save_tasks) - ) - progress(results) - - else: - - with dask.config.set( - scheduler='processes', 
num_workers=num_workers - ): + ).data + + # Check if the array needs to be padded + # First, get the end chunk size of rows and columns + height_end_chunk = time_series_array.chunks[-2][-1] + width_end_chunk = time_series_array.chunks[-1][-1] + + height_padding = 0 + width_padding = 0 + if padding > height_end_chunk: + height_padding = padding - height_end_chunk + if padding > width_end_chunk: + width_padding = padding - width_end_chunk + + if (height_padding > 0) or (width_padding > 0): + # Pad the full array if the end chunk is smaller than the padding + time_series_array = da.pad( + time_series_array, + pad_width=( + (0, 0), + (0, 0), + (0, height_padding), + (0, width_padding), + ), + ).rechunk({0: -1, 1: -1, 2: window_size, 3: window_size}) + + # Add the padding to each chunk + time_series_array = time_series_array.map_overlap( + lambda x: x, + depth={0: 0, 1: 0, 2: padding, 3: padding}, + boundary=0, + trim=False, + ) + + if not ray.is_initialized(): + ray.init(num_cpus=num_workers) + + try: with BatchStore( data=time_series, write_path=process_path, @@ -261,9 +234,17 @@ def create_predict_dataset( padding=padding, compress_method=compress_method, ) as batch_store: - save_tasks = batch_store.save(time_series_array) - with ProgressBar(): - save_tasks.compute() + batch_store.save( + time_series_array, + scheduler=ray_dask_get, + ) + + except RayTaskError as e: + logger.warning(e) + ray.shutdown() + + if ray.is_initialized(): + ray.shutdown() class ReferenceArrays: diff --git a/src/cultionet/data/datasets.py b/src/cultionet/data/datasets.py index 805e12d3..b2edc75b 100644 --- a/src/cultionet/data/datasets.py +++ b/src/cultionet/data/datasets.py @@ -377,7 +377,7 @@ def split_train_val( return train_ds, val_ds def load_file(self, filename: T.Union[str, Path]) -> Data: - return joblib.load(filename) + return Data.from_file(filename) def __getitem__( self, idx: T.Union[int, np.ndarray] @@ -400,7 +400,7 @@ def get(self, idx: int) -> dict: idx (int): The dataset index position. 
""" - batch = Data.from_file(self.data_list_[idx]) + batch = self.load_file(self.data_list_[idx]) batch.x = (batch.x * 1e-4).clip(1e-9, 1) diff --git a/src/cultionet/data/store.py b/src/cultionet/data/store.py index b598b472..5e0cc11f 100644 --- a/src/cultionet/data/store.py +++ b/src/cultionet/data/store.py @@ -10,9 +10,13 @@ from dask.delayed import Delayed from dask.utils import SerializableLock from rasterio.windows import Window +from retry import retry +from ..utils.logging import set_color_logger from .data import Data +logger = set_color_logger(__name__) + class BatchStore: """``dask.array.store`` for data batches.""" @@ -61,6 +65,7 @@ def __setitem__(self, key: tuple, item: np.ndarray) -> None: self.write_batch(item, w=item_window, w_pad=pad_window) + @retry(IOError, tries=5, delay=1) def write_batch(self, x: np.ndarray, w: Window, w_pad: Window): image_height = self.window_size + self.padding * 2 image_width = self.window_size + self.padding * 2 @@ -133,8 +138,14 @@ def write_batch(self, x: np.ndarray, w: Window, w_pad: Window): compress=self.compress_method, ) + try: + _ = batch.from_file(self.write_path / f"{batch_id}.pt") + except EOFError: + raise IOError + def __enter__(self) -> "BatchStore": self.closed = False + return self def __exit__(self, exc_type, exc_value, traceback): @@ -143,5 +154,5 @@ def __exit__(self, exc_type, exc_value, traceback): def _open(self) -> "BatchStore": return self - def save(self, data: da.Array) -> Delayed: - return da.store(data, self, lock=self.lock_, compute=False) + def save(self, data: da.Array, **kwargs) -> Delayed: + da.store(data, self, lock=self.lock_, compute=True, **kwargs) diff --git a/src/cultionet/losses/losses.py b/src/cultionet/losses/losses.py index 9bd84502..e4ae321f 100644 --- a/src/cultionet/losses/losses.py +++ b/src/cultionet/losses/losses.py @@ -4,7 +4,11 @@ import torch import torch.nn as nn import torch.nn.functional as F -from kornia.contrib import distance_transform + +try: + from kornia.contrib import distance_transform +except ImportError: + distance_transform = None try: import torch_topological.nn as topnn @@ -198,7 +202,6 @@ def tanimoto_dist( ypred: torch.Tensor, ytrue: torch.Tensor, smooth: float, - weights: T.Optional[torch.Tensor] = None, ) -> torch.Tensor: """Tanimoto distance.""" @@ -225,19 +228,13 @@ def tanimoto_dist( tpl = tpl.sum(dim=(2, 3)) sq_sum = sq_sum.sum(dim=(2, 3)) - numerator = (tpl * batch_weight) + smooth - denominator = ((sq_sum - tpl) * batch_weight) + smooth + numerator = (tpl * batch_weight).sum(dim=-1) + smooth + denominator = ((sq_sum - tpl) * batch_weight).sum(dim=-1) + smooth distance = numerator / denominator loss = 1.0 - distance - # Apply weights - if weights is not None: - loss = (loss * weights).sum(dim=1) / weights.sum() - else: - loss = loss.mean(dim=1) - - return distance + return loss class TanimotoDistLoss(nn.Module): @@ -281,16 +278,12 @@ class TanimotoDistLoss(nn.Module): def __init__( self, smooth: float = 1e-5, - beta: T.Optional[float] = 0.999, - class_counts: T.Optional[torch.Tensor] = None, transform_logits: bool = False, one_hot_targets: bool = True, ): super().__init__() self.smooth = smooth - self.beta = beta - self.class_counts = class_counts self.preprocessor = LossPreprocessing( transform_logits=transform_logits, @@ -398,6 +391,8 @@ class BoundaryLoss(nn.Module): def __init__(self): super().__init__() + assert distance_transform is not None + def fill_distances( self, distances: torch.Tensor, diff --git a/tests/test_loss.py b/tests/test_loss.py index 
fea39403..aa93048f 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -110,10 +110,10 @@ def test_tanimoto_classification_loss(): loss_func = TanimotoDistLoss() loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS) - assert round(float(loss.item()), 3) == 0.389 + assert round(float(loss.item()), 3) == 0.61 loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS, mask=MASK) - assert round(float(loss.item()), 3) == 0.569 + assert round(float(loss.item()), 3) == 0.431 loss_func = TanimotoComplementLoss() loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS) @@ -129,16 +129,16 @@ def test_tanimoto_classification_loss(): ] ) loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS) - assert round(float(loss.item()), 3) == 0.606 + assert round(float(loss.item()), 3) == 0.717 loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS, mask=MASK) - assert round(float(loss.item()), 3) == 0.63 + assert round(float(loss.item()), 3) == 0.561 def test_tanimoto_regression_loss(): loss_func = TanimotoDistLoss(one_hot_targets=False) loss = loss_func(INPUTS_DIST, DIST_TARGETS) - assert round(float(loss.item()), 3) == 0.583 + assert round(float(loss.item()), 3) == 0.417 loss_func = TanimotoComplementLoss(one_hot_targets=False) loss = loss_func(INPUTS_DIST, DIST_TARGETS)
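
A minimal sketch, using a random stand-in for the (time, band, y, x) stack, of the padded chunking that create_predict_dataset applies above: an identity map_overlap with trim=False gives every spatial window an extra "padding" pixels of context per edge before the blocks reach BatchStore. The array shape and sizes below are illustrative only.

import dask.array as da

window_size = 100
padding = 20

# Stand-in for the masked time series opened through geowombat
stack = da.random.random(
    (2, 3, 300, 300), chunks=(-1, -1, window_size, window_size)
)

# Identity function: no values change, but each block is enlarged by the
# overlap depth, with zeros filled in at the array edges (boundary=0)
padded = stack.map_overlap(
    lambda block: block,
    depth={0: 0, 1: 0, 2: padding, 3: padding},
    boundary=0,
    trim=False,
)

# Every spatial chunk grows from 100 to 100 + 2 * 20 = 140
print(padded.chunks[-1])  # (140, 140, 140)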
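
The patch also replaces the dask LocalCluster with Ray: BatchStore.save now calls da.store(..., compute=True) and forwards scheduler=ray_dask_get, so the store graph runs on the Ray cluster started in create_predict_dataset. Below is a minimal sketch of that hand-off with a toy sink; the NpySink class, its output layout, and lock=False are illustrative stand-ins, since the real __setitem__ target is BatchStore with its SerializableLock.

import tempfile

import dask.array as da
import numpy as np
import ray
from ray.util.dask import ray_dask_get


class NpySink:
    """Toy store target: writes each block to its own .npy file, roughly
    the way BatchStore writes one batch file per window."""

    def __init__(self, out_dir: str):
        self.out_dir = out_dir

    def __setitem__(self, key: tuple, item: np.ndarray) -> None:
        # key is a tuple of slices describing the block's position
        rows, cols = key[-2], key[-1]
        np.save(f"{self.out_dir}/block_{rows.start}_{cols.start}.npy", item)


if not ray.is_initialized():
    ray.init(num_cpus=2)

source = da.random.random((200, 200), chunks=(100, 100))
# Each block writes to its own file, so no lock is needed in this toy case
da.store(
    source, NpySink(tempfile.mkdtemp()), lock=False, compute=True, scheduler=ray_dask_get
)

ray.shutdown()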
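
store.py now decorates write_batch with @retry(IOError, tries=5, delay=1) and re-reads each batch after writing, converting an EOFError into IOError so the write is retried. A standalone sketch of that pattern follows; to_file, the file layout, and the error message are assumptions for illustration, while the retry decorator usage, Data.from_file, and the EOFError handling come from the patch.

from pathlib import Path

from retry import retry

from cultionet.data.data import Data


@retry(IOError, tries=5, delay=1)
def write_and_verify(batch: Data, write_path: Path, batch_id: str) -> None:
    # Assumed writer call; the real write and its compression arguments
    # live in BatchStore.write_batch
    batch.to_file(write_path / f"{batch_id}.pt")
    try:
        # A truncated or corrupt file surfaces as EOFError on re-read
        _ = Data.from_file(write_path / f"{batch_id}.pt")
    except EOFError:
        # Re-raise as IOError so @retry repeats the whole write
        raise IOError(f"Batch {batch_id} did not write cleanly.")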
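
In losses.py, tanimoto_dist now sums the weighted overlap terms over the class dimension before taking the ratio and returns the loss (1 - distance) rather than the distance, which is consistent with expected values in test_loss.py such as 0.569 -> 0.431 and 0.583 -> 0.417 flipping to their complements. A minimal sketch of the reworked term on toy tensors follows; the definitions of tpl, sq_sum, and batch_weight are assumptions for illustration, only the final numerator/denominator form comes from the patch.

import torch

smooth = 1e-5
ypred = torch.rand(2, 3, 8, 8)  # (batch, class, height, width)
ytrue = torch.randint(0, 2, (2, 3, 8, 8)).float()

tpl = (ypred * ytrue).sum(dim=(2, 3))           # assumed per-class overlap
sq_sum = (ypred**2 + ytrue**2).sum(dim=(2, 3))  # assumed per-class squared sum
batch_weight = torch.ones_like(tpl)             # placeholder class weighting

# Weighted terms are summed over classes before the ratio is taken
numerator = (tpl * batch_weight).sum(dim=-1) + smooth
denominator = ((sq_sum - tpl) * batch_weight).sum(dim=-1) + smooth
loss = 1.0 - numerator / denominator            # one value per batch sample
print(loss.shape)                               # torch.Size([2])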