
Commit

fix: torch20 ex (#87)
* ⚡️ sum weights

* ⚡️ move sum

* ✅ update loss tests

* 🎨 formatting

* ⚡️ use common method

* 🎨 formatting

* ➖ make kornia optional

* ➕ increase dependency versions
jgrss authored Aug 15, 2024
1 parent d8bcfb2 commit 6ef9b9c
Showing 6 changed files with 132 additions and 143 deletions.
8 changes: 5 additions & 3 deletions setup.cfg
@@ -31,6 +31,9 @@ python_requires =
>=3.9,<3.11
install_requires =
attrs>=21
dask>=2024.8.0
distributed>=2024.8.0
xarray>=2024.7.0
frozendict>=2.2
frozenlist>=1.3
numpy<2,>=1.22
@@ -45,21 +48,20 @@ install_requires =
decorator==4.4.2
rtree>=0.9.7
graphviz>=0.19
tqdm>=4.62
tqdm>=4.66
pyDeprecate==0.3.1
future>=0.17.1
tensorboard>=2.2
PyYAML>=5.1
lightning>=2.2
torchmetrics>=1.3
einops>=0.7
ray<=2.1,>=2
ray>=2.34
pyarrow>=11
typing-extensions
lz4
rich-argparse
pyogrio>=0.7
kornia>=0.7
geowombat@git+https://github.com/jgrss/geowombat.git
tsaug@git+https://github.com/jgrss/tsaug.git
pygrts@git+https://github.com/jgrss/[email protected]
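
The ➖ "make kornia optional" bullet corresponds to dropping kornia>=0.7 from install_requires above. Below is a minimal sketch of one common way to guard an optional import; the flag and helper names are illustrative and are not taken from the cultionet source.

# Minimal sketch of an optional-dependency guard, assuming kornia is only
# needed for a subset of features. KORNIA_AVAILABLE and require_kornia are
# illustrative names, not cultionet's.
try:
    import kornia  # noqa: F401  (optional; used only by certain features)

    KORNIA_AVAILABLE = True
except ImportError:
    KORNIA_AVAILABLE = False


def require_kornia() -> None:
    """Raise a clear error if an optional kornia-backed feature is requested."""
    if not KORNIA_AVAILABLE:
        raise ImportError(
            "kornia is required for this feature; install it with `pip install kornia`."
        )
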
213 changes: 97 additions & 116 deletions src/cultionet/data/create.py
@@ -1,20 +1,22 @@
import logging
import typing as T
from pathlib import Path

import dask
import dask.array as da
import einops
import geopandas as gpd
import geowombat as gw
import numpy as np
import pandas as pd
import psutil
import ray
import torch
import xarray as xr
from affine import Affine
from dask.diagnostics import ProgressBar
from dask.distributed import Client, LocalCluster, progress
from psutil._common import bytes2human
from rasterio.windows import Window, from_bounds
from ray.exceptions import RayTaskError
from ray.util.dask import enable_dask_on_ray, ray_dask_get
from scipy.ndimage import label as nd_label
from skimage.measure import regionprops
from threadpoolctl import threadpool_limits
@@ -110,7 +112,6 @@ def reshape_and_mask_array(
return time_series


@threadpool_limits.wrap(limits=1, user_api="blas")
def create_predict_dataset(
image_list: T.List[T.List[T.Union[str, Path]]],
region: str,
@@ -121,130 +122,102 @@
ref_res: T.Union[float, T.Tuple[float, float]] = 10.0,
resampling: str = "nearest",
window_size: int = 100,
padding: int = 101,
padding: int = 20,
num_workers: int = 1,
compress_method: T.Union[int, str] = 'zlib',
use_cluster: bool = True,
):
"""Creates a prediction dataset for an image."""

# Read windows larger than the re-chunk window size
read_chunksize = 1024
read_chunksize = 256
while True:
if read_chunksize < window_size:
if read_chunksize < window_size + padding:
read_chunksize *= 2
else:
break

with gw.config.update(ref_res=ref_res):
with gw.open(
image_list,
stack_dim="band",
band_names=list(range(1, len(image_list) + 1)),
resampling=resampling,
chunks=read_chunksize,
) as src_ts:
# Get the time and band count
num_time, num_bands = get_image_list_dims(image_list, src_ts)

time_series = reshape_and_mask_array(
data=src_ts,
num_time=num_time,
num_bands=num_bands,
gain=gain,
offset=offset,
apply_gain=False,
)
total_cpus = psutil.cpu_count(logical=True)
threads_per_worker = total_cpus // num_workers

# Chunk the array into the windows
time_series_array = time_series.chunk(
{"time": -1, "band": -1, "y": window_size, "x": window_size}
).data

# Check if the array needs to be padded
# First, get the end chunk size of rows and columns
height_end_chunk = time_series_array.chunks[-2][-1]
width_end_chunk = time_series_array.chunks[-1][-1]

height_padding = 0
width_padding = 0
if padding > height_end_chunk:
height_padding = padding - height_end_chunk
if padding > width_end_chunk:
width_padding = padding - width_end_chunk

if (height_padding > 0) or (width_padding > 0):
# Pad the full array if the end chunk is smaller than the padding
time_series_array = da.pad(
time_series_array,
pad_width=(
(0, 0),
(0, 0),
(0, height_padding),
(0, width_padding),
),
).rechunk({0: -1, 1: -1, 2: window_size, 3: window_size})

# Add the padding to each chunk
time_series_array = time_series_array.map_overlap(
lambda x: x,
depth={0: 0, 1: 0, 2: padding, 3: padding},
boundary=0,
trim=False,
)
logger.info(f"Opening images with window chunk sizes of {read_chunksize}.")
logger.info(
f"Re-chunking image arrays to chunk sizes of {window_size} with padding of {padding}."
)
logger.info(
f"Virtual memory available is {bytes2human(psutil.virtual_memory().available)}."
)
logger.info(
f"Creating PyTorch dataset with {num_workers} processes and {threads_per_worker} threads."
)

with threadpool_limits(limits=threads_per_worker, user_api="blas"):

with gw.config.update(ref_res=ref_res):
with gw.open(
image_list,
stack_dim="band",
band_names=list(range(1, len(image_list) + 1)),
resampling=resampling,
chunks=read_chunksize,
) as src_ts:
# Get the time and band count
num_time, num_bands = get_image_list_dims(image_list, src_ts)

if use_cluster:
with dask.config.set(
time_series = reshape_and_mask_array(
data=src_ts,
num_time=num_time,
num_bands=num_bands,
gain=gain,
offset=offset,
apply_gain=False,
)

# Chunk the array into the windows
time_series_array = time_series.chunk(
{
"distributed.worker.memory.terminate": False,
"distributed.comm.retry.count": 10,
"distributed.comm.timeouts.connect": 5,
"distributed.scheduler.allowed-failures": 20,
"distributed.worker.memory.pause": 0.95,
"distributed.worker.memory.target": 0.97,
"distributed.worker.memory.spill": False,
"distributed.scheduler.worker-saturation": 1.0,
"time": -1,
"band": -1,
"y": window_size,
"x": window_size,
}
):
with LocalCluster(
processes=True,
n_workers=num_workers,
threads_per_worker=1,
memory_limit="6GB", # per worker limit
silence_logs=logging.ERROR,
) as cluster:
with Client(cluster) as client:
with BatchStore(
data=time_series,
write_path=process_path,
res=ref_res,
resampling=resampling,
region=region,
start_date=pd.to_datetime(
Path(image_list[0]).stem,
format=date_format,
).strftime("%Y%m%d"),
end_date=pd.to_datetime(
Path(image_list[-1]).stem,
format=date_format,
).strftime("%Y%m%d"),
window_size=window_size,
padding=padding,
compress_method=compress_method,
) as batch_store:
save_tasks = batch_store.save(
time_series_array
)
results = client.gather(
client.persist(save_tasks)
)
progress(results)

else:

with dask.config.set(
scheduler='processes', num_workers=num_workers
):
).data

# Check if the array needs to be padded
# First, get the end chunk size of rows and columns
height_end_chunk = time_series_array.chunks[-2][-1]
width_end_chunk = time_series_array.chunks[-1][-1]

height_padding = 0
width_padding = 0
if padding > height_end_chunk:
height_padding = padding - height_end_chunk
if padding > width_end_chunk:
width_padding = padding - width_end_chunk

if (height_padding > 0) or (width_padding > 0):
# Pad the full array if the end chunk is smaller than the padding
time_series_array = da.pad(
time_series_array,
pad_width=(
(0, 0),
(0, 0),
(0, height_padding),
(0, width_padding),
),
).rechunk({0: -1, 1: -1, 2: window_size, 3: window_size})

# Add the padding to each chunk
time_series_array = time_series_array.map_overlap(
lambda x: x,
depth={0: 0, 1: 0, 2: padding, 3: padding},
boundary=0,
trim=False,
)

if not ray.is_initialized():
ray.init(num_cpus=num_workers)

try:
with BatchStore(
data=time_series,
write_path=process_path,
@@ -261,9 +234,17 @@ def create_predict_dataset(
padding=padding,
compress_method=compress_method,
) as batch_store:
save_tasks = batch_store.save(time_series_array)
with ProgressBar():
save_tasks.compute()
batch_store.save(
time_series_array,
scheduler=ray_dask_get,
)

except RayTaskError as e:
logger.warning(e)
ray.shutdown()

if ray.is_initialized():
ray.shutdown()


class ReferenceArrays:
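
In the new create_predict_dataset, the time series is re-chunked into window_size blocks and map_overlap(..., trim=False) attaches a padding halo to every spatial chunk before the batches are written. The following is a minimal, self-contained sketch of that padding step on a toy array; the shapes and depths are illustrative, not cultionet's defaults.

import dask.array as da
import numpy as np

# Toy (time, band, y, x) array chunked into 4 x 4 spatial windows.
stack = da.from_array(
    np.arange(2 * 3 * 8 * 8, dtype="float32").reshape(2, 3, 8, 8),
    chunks=(2, 3, 4, 4),
)

# Attach a 2-pixel halo to every spatial chunk. With trim=False the overlap
# is kept, so each 4 x 4 window grows to 8 x 8 (4 + 2 * 2) and carries its
# own padding, just like the padded prediction windows in create.py.
padded = stack.map_overlap(
    lambda block: block,
    depth={0: 0, 1: 0, 2: 2, 3: 2},
    boundary=0,
    trim=False,
)

first_window = padded.blocks[0, 0, 0, 0].compute()
print(first_window.shape)  # (2, 3, 8, 8)

In the commit itself the padded blocks are then written via BatchStore.save(..., scheduler=ray_dask_get), so each Dask store task runs on a Ray worker instead of the previous Dask distributed LocalCluster.
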
4 changes: 2 additions & 2 deletions src/cultionet/data/datasets.py
@@ -377,7 +377,7 @@ def split_train_val(
return train_ds, val_ds

def load_file(self, filename: T.Union[str, Path]) -> Data:
return joblib.load(filename)
return Data.from_file(filename)

def __getitem__(
self, idx: T.Union[int, np.ndarray]
@@ -400,7 +400,7 @@ def get(self, idx: int) -> dict:
idx (int): The dataset index position.
"""

batch = Data.from_file(self.data_list_[idx])
batch = self.load_file(self.data_list_[idx])

batch.x = (batch.x * 1e-4).clip(1e-9, 1)

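
The dataset change swaps joblib deserialization for Data.from_file and routes get() through load_file(), so every item access shares one loading path. A minimal sketch of that pattern follows; Batch and BatchDataset are hypothetical stand-ins, not cultionet classes.

import pickle
import typing as T
from pathlib import Path


class Batch:
    """Hypothetical stand-in for cultionet's Data container."""

    def __init__(self, x):
        self.x = x

    @classmethod
    def from_file(cls, filename: T.Union[str, Path]) -> "Batch":
        with open(filename, "rb") as f:
            return cls(pickle.load(f))


class BatchDataset:
    """Hypothetical dataset; only the loading path mirrors datasets.py."""

    def __init__(self, data_list: T.List[Path]):
        self.data_list_ = list(data_list)

    def load_file(self, filename: T.Union[str, Path]) -> Batch:
        # Single place to swap the storage backend
        # (here: joblib.load -> a from_file classmethod).
        return Batch.from_file(filename)

    def get(self, idx: int) -> Batch:
        # Item access funnels through load_file(), so the format change
        # only needs to be made once.
        return self.load_file(self.data_list_[idx])
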
15 changes: 13 additions & 2 deletions src/cultionet/data/store.py
@@ -10,9 +10,13 @@
from dask.delayed import Delayed
from dask.utils import SerializableLock
from rasterio.windows import Window
from retry import retry

from ..utils.logging import set_color_logger
from .data import Data

logger = set_color_logger(__name__)


class BatchStore:
"""``dask.array.store`` for data batches."""
@@ -61,6 +65,7 @@ def __setitem__(self, key: tuple, item: np.ndarray) -> None:

self.write_batch(item, w=item_window, w_pad=pad_window)

@retry(IOError, tries=5, delay=1)
def write_batch(self, x: np.ndarray, w: Window, w_pad: Window):
image_height = self.window_size + self.padding * 2
image_width = self.window_size + self.padding * 2
@@ -133,8 +138,14 @@ def write_batch(self, x: np.ndarray, w: Window, w_pad: Window):
compress=self.compress_method,
)

try:
_ = batch.from_file(self.write_path / f"{batch_id}.pt")
except EOFError:
raise IOError

def __enter__(self) -> "BatchStore":
self.closed = False

return self

def __exit__(self, exc_type, exc_value, traceback):
Expand All @@ -143,5 +154,5 @@ def __exit__(self, exc_type, exc_value, traceback):
def _open(self) -> "BatchStore":
return self

def save(self, data: da.Array) -> Delayed:
return da.store(data, self, lock=self.lock_, compute=False)
def save(self, data: da.Array, **kwargs) -> Delayed:
da.store(data, self, lock=self.lock_, compute=True, **kwargs)
(Two additional changed files are not shown above.)
