From d8bcfb2c7710aad0e3a821935a99a34dda237173 Mon Sep 17 00:00:00 2001 From: Jordan Graesser Date: Wed, 14 Aug 2024 15:41:10 +1000 Subject: [PATCH] torch20 transfer (#86) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🎨 formatting * 🔥 remove gain from dask storage * 🎨 formatting * 🎨 formatting * ⚡️ change CLI defaults * ✨ new loss enum * ⚡️ relative import * ✅ update tests * 🔒️ upgrade setuptools --- docs/requirements.txt | 2 +- pyproject.toml | 2 +- setup.cfg | 3 +- src/cultionet/data/create.py | 144 +++++++++++++++++++--------- src/cultionet/data/store.py | 4 +- src/cultionet/enums/__init__.py | 1 + src/cultionet/losses/__init__.py | 1 + src/cultionet/losses/losses.py | 154 ++++++++++++++++-------------- src/cultionet/models/lightning.py | 80 ++++++++-------- src/cultionet/scripts/args.yml | 6 +- tests/test_loss.py | 23 ++++- 11 files changed, 248 insertions(+), 172 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index ceb2bdbe..01840c6b 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -25,7 +25,7 @@ tensorboard>=2.2.0 PyYAML>=5.1 geowombat@git+https://github.com/jgrss/geowombat.git tsaug@git+https://github.com/jgrss/tsaug.git -setuptools==59.5.0 +setuptools>=70 numpydoc sphinx sphinx-automodapi diff --git a/pyproject.toml b/pyproject.toml index 7ab627ab..639eb319 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ - 'setuptools>=65.5.1', + 'setuptools>=70', 'wheel', 'numpy<2,>=1.22', ] diff --git a/setup.cfg b/setup.cfg index 58c58220..56c0466e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,7 @@ package_dir= packages=find: include_package_data = True setup_requires = - setuptools>=65.5.1 + setuptools>=70 wheel numpy<2,>=1.22 python_requires = @@ -63,7 +63,6 @@ install_requires = geowombat@git+https://github.com/jgrss/geowombat.git tsaug@git+https://github.com/jgrss/tsaug.git pygrts@git+https://github.com/jgrss/pygrts.git@v1.4.1 - setuptools>=65.5.1 [options.extras_require] docs = numpydoc diff --git a/src/cultionet/data/create.py b/src/cultionet/data/create.py index 7d81d4f0..962de174 100644 --- a/src/cultionet/data/create.py +++ b/src/cultionet/data/create.py @@ -1,3 +1,4 @@ +import logging import typing as T from pathlib import Path @@ -11,6 +12,7 @@ import torch import xarray as xr from affine import Affine +from dask.diagnostics import ProgressBar from dask.distributed import Client, LocalCluster, progress from rasterio.windows import Window, from_bounds from scipy.ndimage import label as nd_label @@ -71,11 +73,14 @@ def reshape_and_mask_array( num_bands: int, gain: float, offset: int, + apply_gain: bool = True, ) -> xr.DataArray: """Reshapes an array and masks no-data values.""" - src_ts_stack = xr.DataArray( - # Date are stored [(band x time) x height x width] + dtype = 'float32' if apply_gain else 'int16' + + time_series = xr.DataArray( + # Data are stored [(band x time) x height x width] ( data.data.reshape( num_bands, @@ -83,7 +88,7 @@ def reshape_and_mask_array( data.gw.nrows, data.gw.ncols, ).transpose(1, 0, 2, 3) - ).astype('float32'), + ).astype(dtype), dims=('time', 'band', 'y', 'x'), coords={ 'time': range(num_time), @@ -94,12 +99,18 @@ def reshape_and_mask_array( attrs=data.attrs.copy(), ) - with xr.set_options(keep_attrs=True): - time_series = (src_ts_stack.gw.mask_nodata() * gain + offset).fillna(0) + if apply_gain: + + with xr.set_options(keep_attrs=True): + # Mask and scale the data + time_series = ( + time_series.gw.mask_nodata() * gain 
+ offset + ).fillna(0) return time_series +@threadpool_limits.wrap(limits=1, user_api="blas") def create_predict_dataset( image_list: T.List[T.List[T.Union[str, Path]]], region: str, @@ -113,26 +124,36 @@ def create_predict_dataset( padding: int = 101, num_workers: int = 1, compress_method: T.Union[int, str] = 'zlib', + use_cluster: bool = True, ): """Creates a prediction dataset for an image.""" + # Read windows larger than the re-chunk window size + read_chunksize = 1024 + while True: + if read_chunksize < window_size: + read_chunksize *= 2 + else: + break + with gw.config.update(ref_res=ref_res): with gw.open( image_list, stack_dim="band", band_names=list(range(1, len(image_list) + 1)), resampling=resampling, - chunks=512, + chunks=read_chunksize, ) as src_ts: - + # Get the time and band count num_time, num_bands = get_image_list_dims(image_list, src_ts) - time_series: xr.DataArray = reshape_and_mask_array( + time_series = reshape_and_mask_array( data=src_ts, num_time=num_time, num_bands=num_bands, gain=gain, offset=offset, + apply_gain=False, ) # Chunk the array into the windows @@ -172,42 +193,77 @@ def create_predict_dataset( trim=False, ) - with dask.config.set( - { - "distributed.worker.memory.terminate": False, - "distributed.comm.retry.count": 10, - "distributed.comm.timeouts.connect": 5, - "distributed.scheduler.allowed-failures": 20, - } - ): - with LocalCluster( - processes=True, - n_workers=num_workers, - threads_per_worker=1, - memory_target_fraction=0.97, - memory_limit="4GB", # per worker limit - ) as cluster: - with Client(cluster) as client: - with BatchStore( - data=time_series, - write_path=process_path, - res=ref_res, - resampling=resampling, - region=region, - start_date=pd.to_datetime( - Path(image_list[0]).stem, format=date_format - ).strftime("%Y%m%d"), - end_date=pd.to_datetime( - Path(image_list[-1]).stem, format=date_format - ).strftime("%Y%m%d"), - window_size=window_size, - padding=padding, - compress_method=compress_method, - gain=gain, - ) as batch_store: - save_tasks = batch_store.save(time_series_array) - results = client.persist(save_tasks) - progress(results) + if use_cluster: + with dask.config.set( + { + "distributed.worker.memory.terminate": False, + "distributed.comm.retry.count": 10, + "distributed.comm.timeouts.connect": 5, + "distributed.scheduler.allowed-failures": 20, + "distributed.worker.memory.pause": 0.95, + "distributed.worker.memory.target": 0.97, + "distributed.worker.memory.spill": False, + "distributed.scheduler.worker-saturation": 1.0, + } + ): + with LocalCluster( + processes=True, + n_workers=num_workers, + threads_per_worker=1, + memory_limit="6GB", # per worker limit + silence_logs=logging.ERROR, + ) as cluster: + with Client(cluster) as client: + with BatchStore( + data=time_series, + write_path=process_path, + res=ref_res, + resampling=resampling, + region=region, + start_date=pd.to_datetime( + Path(image_list[0]).stem, + format=date_format, + ).strftime("%Y%m%d"), + end_date=pd.to_datetime( + Path(image_list[-1]).stem, + format=date_format, + ).strftime("%Y%m%d"), + window_size=window_size, + padding=padding, + compress_method=compress_method, + ) as batch_store: + save_tasks = batch_store.save( + time_series_array + ) + results = client.gather( + client.persist(save_tasks) + ) + progress(results) + + else: + + with dask.config.set( + scheduler='processes', num_workers=num_workers + ): + with BatchStore( + data=time_series, + write_path=process_path, + res=ref_res, + resampling=resampling, + region=region, + 
start_date=pd.to_datetime( + Path(image_list[0]).stem, format=date_format + ).strftime("%Y%m%d"), + end_date=pd.to_datetime( + Path(image_list[-1]).stem, format=date_format + ).strftime("%Y%m%d"), + window_size=window_size, + padding=padding, + compress_method=compress_method, + ) as batch_store: + save_tasks = batch_store.save(time_series_array) + with ProgressBar(): + save_tasks.compute() class ReferenceArrays: diff --git a/src/cultionet/data/store.py b/src/cultionet/data/store.py index 582c7cc3..b598b472 100644 --- a/src/cultionet/data/store.py +++ b/src/cultionet/data/store.py @@ -31,7 +31,6 @@ def __init__( window_size: int, padding: int, compress_method: Union[int, str], - gain: float, ): self.data = data self.res = res @@ -43,7 +42,6 @@ def __init__( self.window_size = window_size self.padding = padding self.compress_method = compress_method - self.gain = gain def __setitem__(self, key: tuple, item: np.ndarray) -> None: time_range, index_range, y, x = key @@ -87,7 +85,7 @@ def write_batch(self, x: np.ndarray, w: Window, w_pad: Window): ) x = einops.rearrange( - torch.from_numpy(x / self.gain).to(dtype=torch.int32), + torch.from_numpy(x.astype('int32')).to(dtype=torch.int32), 't c h w -> 1 c t h w', ) diff --git a/src/cultionet/enums/__init__.py b/src/cultionet/enums/__init__.py index 666246c1..af5a975c 100644 --- a/src/cultionet/enums/__init__.py +++ b/src/cultionet/enums/__init__.py @@ -54,6 +54,7 @@ class LossTypes(StrEnum): CLASS_BALANCED_MSE = "ClassBalancedMSELoss" TANIMOTO_COMPLEMENT = "TanimotoComplementLoss" TANIMOTO = "TanimotoDistLoss" + TANIMOTO_COMBINED = "TanimotoCombined" TOPOLOGY = "TopologyLoss" diff --git a/src/cultionet/losses/__init__.py b/src/cultionet/losses/__init__.py index 26a8bdc4..d11d04ff 100644 --- a/src/cultionet/losses/__init__.py +++ b/src/cultionet/losses/__init__.py @@ -1,6 +1,7 @@ from .losses import ( BoundaryLoss, ClassBalancedMSELoss, + CombinedLoss, LossPreprocessing, TanimotoComplementLoss, TanimotoDistLoss, diff --git a/src/cultionet/losses/losses.py b/src/cultionet/losses/losses.py index 4216dd7a..9bd84502 100644 --- a/src/cultionet/losses/losses.py +++ b/src/cultionet/losses/losses.py @@ -1,5 +1,4 @@ import typing as T -import warnings import einops import torch @@ -23,7 +22,10 @@ def __init__( self.one_hot_targets = one_hot_targets def forward( - self, inputs: torch.Tensor, targets: torch.Tensor + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, ) -> T.Tuple[torch.Tensor, torch.Tensor]: """Forward pass to transform logits. @@ -47,9 +49,50 @@ def forward( else: targets = einops.rearrange(targets, 'b h w -> b 1 h w') + if mask is not None: + # Apply a mask to zero-out weight + inputs = inputs * mask + targets = targets * mask + return inputs, targets +class CombinedLoss(nn.Module): + def __init__(self, losses: T.List[T.Callable]): + super().__init__() + + self.losses = losses + + def forward( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Performs a single forward pass. + + Args: + inputs: Predictions from model (probabilities or labels), shaped (B, C, H, W). + targets: Ground truth values, shaped (B, C, H, W). + mask: Values to mask (0) or keep (1), shaped (B, 1, H, W). 
+ + Returns: + Average distance loss (float) + """ + + loss = 0.0 + for loss_func in self.losses: + loss = loss + loss_func( + inputs=inputs, + targets=targets, + mask=mask, + ) + + loss = loss / len(self.losses) + + return loss + + class TanimotoComplementLoss(nn.Module): """Tanimoto distance loss. @@ -102,37 +145,27 @@ def tanimoto_distance( self, y: torch.Tensor, yhat: torch.Tensor, - mask: T.Optional[torch.Tensor] = None, - weights: T.Optional[torch.Tensor] = None, ) -> torch.Tensor: scale = 1.0 / self.depth - if mask is not None: - y = y * mask - yhat = yhat * mask - tpl = y * yhat sq_sum = y**2 + yhat**2 - tpl = tpl.sum(dim=(2, 3)) - sq_sum = sq_sum.sum(dim=(2, 3)) - - if weights is not None: - tpl = tpl * weights - sq_sum = sq_sum * weights + tpl = tpl.sum(dim=(1, 2, 3)) + sq_sum = sq_sum.sum(dim=(1, 2, 3)) denominator = 0.0 for d in range(0, self.depth): a = 2.0**d b = -(2.0 * a - 1.0) denominator = denominator + torch.reciprocal( - (a * sq_sum) + (b * tpl) - ) - denominator = torch.nan_to_num( - denominator, nan=0.0, posinf=0.0, neginf=0.0 + ((a * sq_sum) + (b * tpl)) + self.smooth ) - return ((tpl * denominator) * scale).sum(dim=1) + numerator = tpl + self.smooth + distance = (numerator * denominator) * scale + + return 1.0 - distance def forward( self, @@ -143,19 +176,20 @@ def forward( """Performs a single forward pass. Args: - inputs: Predictions from model (probabilities or labels). - targets: Ground truth values. + inputs: Predictions from model (probabilities or labels), shaped (B, C, H, W). + targets: Ground truth values, shaped (B, C, H, W). + mask: Values to mask (0) or keep (1), shaped (B, 1, H, W). Returns: Tanimoto distance loss (float) """ - inputs, targets = self.preprocessor(inputs, targets) - - loss = 1.0 - self.tanimoto_distance(targets, inputs, mask=mask) - compl_loss = 1.0 - self.tanimoto_distance( - 1.0 - targets, 1.0 - inputs, mask=mask + inputs, targets = self.preprocessor( + inputs=inputs, targets=targets, mask=mask ) - loss = (loss + compl_loss) * 0.5 + + loss1 = self.tanimoto_distance(targets, inputs) + loss2 = self.tanimoto_distance(1.0 - targets, 1.0 - inputs) + loss = (loss1 + loss2) * 0.5 return loss.mean() @@ -163,14 +197,11 @@ def forward( def tanimoto_dist( ypred: torch.Tensor, ytrue: torch.Tensor, - scale_pos_weight: bool, - class_counts: T.Union[None, torch.Tensor], - beta: float, smooth: float, - mask: T.Optional[torch.Tensor] = None, weights: T.Optional[torch.Tensor] = None, ) -> torch.Tensor: """Tanimoto distance.""" + ytrue = ytrue.to(dtype=ypred.dtype) # Take the batch mean of the channel sums @@ -187,20 +218,6 @@ def tanimoto_dist( batch_weight, ) - if scale_pos_weight: - if class_counts is None: - class_counts = ytrue.sum(dim=0) - else: - class_counts = class_counts - effective_num = 1.0 - beta**class_counts - weights = (1.0 - beta) / effective_num - weights = weights / weights.sum() * class_counts.shape[0] - - # Apply a mask to zero-out gradients where mask == 0 - if mask is not None: - ytrue = ytrue * mask - ypred = ypred * mask - tpl = ypred * ytrue sq_sum = ypred**2 + ytrue**2 @@ -208,14 +225,18 @@ def tanimoto_dist( tpl = tpl.sum(dim=(2, 3)) sq_sum = sq_sum.sum(dim=(2, 3)) - if weights is not None: - tpl = tpl * weights - sq_sum = sq_sum * weights - - numerator = (tpl * batch_weight + smooth).sum(dim=1) - denominator = ((sq_sum - tpl) * batch_weight + smooth).sum(dim=1) + numerator = (tpl * batch_weight) + smooth + denominator = ((sq_sum - tpl) * batch_weight) + smooth distance = numerator / denominator + loss = 1.0 - distance + + # 
Apply weights
+    if weights is not None:
+        loss = (loss * weights).sum(dim=1) / weights.sum()
+    else:
+        loss = loss.mean(dim=1)
+
-    return distance
+    return loss


@@ -262,22 +283,14 @@ def __init__(
         smooth: float = 1e-5,
         beta: T.Optional[float] = 0.999,
         class_counts: T.Optional[torch.Tensor] = None,
-        scale_pos_weight: bool = False,
         transform_logits: bool = False,
         one_hot_targets: bool = True,
     ):
         super().__init__()

-        if scale_pos_weight and (class_counts is None):
-            warnings.warn(
-                "Cannot balance classes without class weights. Weights will be derived for each batch.",
-                UserWarning,
-            )
-
         self.smooth = smooth
         self.beta = beta
         self.class_counts = class_counts
-        self.scale_pos_weight = scale_pos_weight

         self.preprocessor = LossPreprocessing(
             transform_logits=transform_logits,
@@ -293,34 +306,29 @@ def forward(
         """Performs a single forward pass.

         Args:
-            inputs: Predictions from model (probabilities, logits or labels).
-            targets: Ground truth values.
+            inputs: Predictions from model (probabilities or labels), shaped (B, C, H, W).
+            targets: Ground truth values, shaped (B, C, H, W).
+            mask: Values to mask (0) or keep (1), shaped (B, 1, H, W).

         Returns:
             Tanimoto distance loss (float)
         """
-        inputs, targets = self.preprocessor(inputs, targets)
+        inputs, targets = self.preprocessor(
+            inputs=inputs, targets=targets, mask=mask
+        )

-        loss = 1.0 - tanimoto_dist(
+        loss1 = tanimoto_dist(
             inputs,
             targets,
-            scale_pos_weight=self.scale_pos_weight,
-            class_counts=self.class_counts,
-            beta=self.beta,
             smooth=self.smooth,
-            mask=mask,
         )
-        compl_loss = 1.0 - tanimoto_dist(
+        loss2 = tanimoto_dist(
             1.0 - inputs,
             1.0 - targets,
-            scale_pos_weight=self.scale_pos_weight,
-            class_counts=self.class_counts,
-            beta=self.beta,
             smooth=self.smooth,
-            mask=mask,
         )
-        loss = (loss + compl_loss) * 0.5
+        loss = (loss1 + loss2) * 0.5

         return loss.mean()
diff --git a/src/cultionet/models/lightning.py b/src/cultionet/models/lightning.py
index a5a6c2b2..57251e82 100644
--- a/src/cultionet/models/lightning.py
+++ b/src/cultionet/models/lightning.py
@@ -150,7 +150,7 @@ def get_true_labels(
         mask = None
         if batch.y.min() == -1:
             mask = torch.where(batch.y == -1, 0, 1).to(
-                dtype=torch.uint8, device=batch.y.device
+                dtype=torch.long, device=batch.y.device
             )
             mask = einops.rearrange(mask, 'b h w -> b 1 h w')

@@ -163,14 +163,6 @@ def get_true_labels(
         "mask": mask,
     }

-    # def on_validation_epoch_end(self, *args, **kwargs):
-    #     """Save the model on validation end."""
-    #     if self.logger.save_dir is not None:
-    #         model_file = Path(self.logger.save_dir) / f"{self.model_name}.pt"
-    #         if model_file.is_file():
-    #             model_file.unlink()
-    #         torch.save(self.state_dict(), model_file)
-
     def calc_loss(
         self,
         batch: T.Union[Data, T.List],
@@ -322,27 +314,6 @@ def calc_loss(
             weights["crop_cmse_loss"] = 0.1
             loss = loss + crop_cmse_loss * weights["crop_cmse_loss"]

-        # Topology loss
-        # topo_loss = self.topo_loss(
-        #     predictions["edge"].squeeze(dim=1),
-        #     true_labels_dict["true_edge"],
-        # )
-        # weights["topo_loss"] = 0.1
-        # loss = loss + topo_loss * weights["topo_loss"]
-
-        # if predictions["crop_type"] is not None:
-        #     # Upstream (deep) loss on crop-type
-        #     crop_type_star_loss = self.crop_type_star_loss(
-        #         predictions["crop_type_star"],
-        #         true_labels_dict["true_crop_type"],
-        #     )
-        #     loss = loss + crop_type_star_loss
-        #     # Loss on crop-type
-        #     crop_type_loss = self.crop_type_loss(
-        #         predictions["crop_type"], true_labels_dict["true_crop_type"]
-        #     )
-        #     loss = loss + crop_type_loss
-
         return loss / sum(weights.values())

     def mask_rcnn_forward(
@@ -897,6 +868,22 @@
def __init__( one_hot_targets=False ), }, + LossTypes.TANIMOTO_COMBINED: { + "classification": cnetlosses.CombinedLoss( + losses=[ + cnetlosses.TanimotoDistLoss(), + cnetlosses.TanimotoComplementLoss(), + ], + ), + "regression": cnetlosses.CombinedLoss( + losses=[ + cnetlosses.TanimotoDistLoss(one_hot_targets=False), + cnetlosses.TanimotoComplementLoss( + one_hot_targets=False + ), + ], + ), + }, LossTypes.TOPOLOGY: { "classification": cnetlosses.TopologyLoss(), }, @@ -906,17 +893,10 @@ def __init__( checkpoint_path=str(pretrained_ckpt_file) ).cultionet_model - # import torchinfo - # torchinfo.summary( - # model=self.cultionet_model.mask_model, - # input_size=[(1, 5, 13, 100, 100), (1, 64, 100, 100)], - # device="cuda", - # ) - - # Freeze all parameters if not finetuning the full model if self.finetune != "all": - for name, param in self.cultionet_model.named_parameters(): - param.requires_grad = False + + # Freeze all parameters if not finetuning the full model + self.freeze(self.cultionet_model) if self.finetune == "fc": # Unfreeze fully connected layers @@ -988,6 +968,10 @@ def __init__( def is_transfer_model(self) -> bool: return True + def freeze(self, layer): + for param in layer.parameters(): + param.requires_grad = False + def unfreeze(self, layer): for param in layer.parameters(): param.requires_grad = True @@ -1076,6 +1060,22 @@ def __init__( one_hot_targets=False ), }, + LossTypes.TANIMOTO_COMBINED: { + "classification": cnetlosses.CombinedLoss( + losses=[ + cnetlosses.TanimotoDistLoss(), + cnetlosses.TanimotoComplementLoss(), + ], + ), + "regression": cnetlosses.CombinedLoss( + losses=[ + cnetlosses.TanimotoDistLoss(one_hot_targets=False), + cnetlosses.TanimotoComplementLoss( + one_hot_targets=False + ), + ], + ), + }, LossTypes.TOPOLOGY: { "classification": cnetlosses.TopologyLoss(), }, diff --git a/src/cultionet/scripts/args.yml b/src/cultionet/scripts/args.yml index 8e12c157..39d5c798 100644 --- a/src/cultionet/scripts/args.yml +++ b/src/cultionet/scripts/args.yml @@ -461,8 +461,8 @@ train: long: loss-name help: The loss method name kwargs: - default: 'TanimotoComplementLoss' - choices: ['TanimotoDistLoss', 'TanimotoComplementLoss'] + default: 'TanimotoCombined' + choices: ['TanimotoDistLoss', 'TanimotoComplementLoss', 'TanimotoCombined'] learning_rate: short: lr long: learning-rate @@ -609,7 +609,7 @@ predict: long: padding help: The read padding around the window (padding is sliced off before writing) kwargs: - default: 101 + default: 20 type: '&int' mode: short: '' diff --git a/tests/test_loss.py b/tests/test_loss.py index 6efbd02a..fea39403 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -5,6 +5,7 @@ from einops import rearrange from cultionet.losses import ( + CombinedLoss, LossPreprocessing, TanimotoComplementLoss, TanimotoDistLoss, @@ -109,23 +110,35 @@ def test_tanimoto_classification_loss(): loss_func = TanimotoDistLoss() loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS) - assert round(float(loss.item()), 3) == 0.61 + assert round(float(loss.item()), 3) == 0.389 loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS, mask=MASK) - assert round(float(loss.item()), 3) == 0.608 + assert round(float(loss.item()), 3) == 0.569 loss_func = TanimotoComplementLoss() loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS) - assert round(float(loss.item()), 3) == 0.649 + assert round(float(loss.item()), 3) == 0.824 loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS, mask=MASK) - assert round(float(loss.item()), 3) == 0.647 + assert round(float(loss.item()), 3) == 
0.692 + + loss_func = CombinedLoss( + losses=[ + TanimotoDistLoss(), + TanimotoComplementLoss(), + ] + ) + loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS) + assert round(float(loss.item()), 3) == 0.606 + + loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS, mask=MASK) + assert round(float(loss.item()), 3) == 0.63 def test_tanimoto_regression_loss(): loss_func = TanimotoDistLoss(one_hot_targets=False) loss = loss_func(INPUTS_DIST, DIST_TARGETS) - assert round(float(loss.item()), 3) == 0.417 + assert round(float(loss.item()), 3) == 0.583 loss_func = TanimotoComplementLoss(one_hot_targets=False) loss = loss_func(INPUTS_DIST, DIST_TARGETS)
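
Notes on the new loss stack (illustrative sketches, not part of the applied diff):

The reworked `tanimoto_dist` reduces to a short standalone function. The snippet below mirrors the patch's math — a batch-weighted Tanimoto similarity per channel, with the loss averaged over the complemented pair — under simplified assumptions: the infinite-weight guard is collapsed to `nan_to_num` (the patch instead substitutes the maximum finite weight), and the optional `weights`/`mask` handling is omitted.

    import torch

    def tanimoto_loss(ypred: torch.Tensor, ytrue: torch.Tensor, smooth: float = 1e-5) -> torch.Tensor:
        # Channel weights from the squared batch-mean channel volume
        volume = ytrue.sum(dim=(2, 3)).mean(dim=0)
        batch_weight = torch.nan_to_num(torch.reciprocal(volume**2), posinf=0.0)

        tpl = (ypred * ytrue).sum(dim=(2, 3))
        sq_sum = (ypred**2 + ytrue**2).sum(dim=(2, 3))

        # Continuous Tanimoto similarity: x.y / (|x|^2 + |y|^2 - x.y)
        distance = ((tpl * batch_weight) + smooth) / (((sq_sum - tpl) * batch_weight) + smooth)

        # Per-sample loss, averaged over channels
        return (1.0 - distance).mean(dim=1)

    def dual_tanimoto_loss(ypred: torch.Tensor, ytrue: torch.Tensor) -> torch.Tensor:
        # Average the loss on (y, yhat) with the loss on the complements (1 - y, 1 - yhat)
        return 0.5 * (tanimoto_loss(ypred, ytrue) + tanimoto_loss(1.0 - ypred, 1.0 - ytrue))

    torch.manual_seed(0)
    ypred = torch.rand(2, 2, 8, 8).softmax(dim=1)
    ytrue = torch.nn.functional.one_hot(torch.randint(0, 2, (2, 8, 8)), num_classes=2)
    ytrue = ytrue.permute(0, 3, 1, 2).float()
    print(dual_tanimoto_loss(ypred, ytrue))  # one loss value per batch sample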
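`TanimotoCombined`, the new `--loss-name` default, simply averages the two Tanimoto losses through `CombinedLoss`. A minimal usage sketch — the tensor shapes are assumptions taken from the docstrings and tests above, not values from the patch itself:

    import torch
    from cultionet.losses import CombinedLoss, TanimotoComplementLoss, TanimotoDistLoss

    probs = torch.rand(2, 2, 16, 16).softmax(dim=1)  # (B, C, H, W) class probabilities
    targets = torch.randint(0, 2, (2, 16, 16))       # (B, H, W) integer labels, one-hot encoded internally
    mask = torch.ones(2, 1, 16, 16)                  # (B, 1, H, W); 1 = keep, 0 = ignore

    loss_func = CombinedLoss(losses=[TanimotoDistLoss(), TanimotoComplementLoss()])
    loss = loss_func(probs, targets, mask=mask)

Consistent with this, the combined expectations in tests/test_loss.py are the arithmetic means of the individual losses: (0.389 + 0.824) / 2 ≈ 0.606 unmasked and (0.569 + 0.692) / 2 ≈ 0.63 masked.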
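The new `use_cluster` flag in `create_predict_dataset` switches between a distributed `LocalCluster` and the plain multiprocessing scheduler with a text progress bar. The pattern in isolation, with a generic dask collection standing in for the batch-store save tasks (hypothetical, for illustration only; the patch additionally tunes distributed memory settings and per-worker limits):

    import dask
    import dask.array as da
    from dask.diagnostics import ProgressBar
    from dask.distributed import Client, LocalCluster, progress

    def compute_tasks(tasks, use_cluster: bool = True, num_workers: int = 2):
        if use_cluster:
            # Distributed scheduler: one process per worker, dashboard, rich progress
            with LocalCluster(
                processes=True, n_workers=num_workers, threads_per_worker=1
            ) as cluster:
                with Client(cluster) as client:
                    persisted = client.persist(tasks)
                    progress(persisted)
                    return persisted.compute()
        else:
            # Single-machine process pool with a simple text progress bar
            with dask.config.set(scheduler="processes", num_workers=num_workers):
                with ProgressBar():
                    return tasks.compute()

    if __name__ == "__main__":  # guard needed for process-based schedulers on spawn platforms
        print(compute_tasks(da.random.random((2_000, 2_000), chunks=500).mean(), use_cluster=False))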