From 024bd44b9ff9830357ae95d10be1f32c2d062fae Mon Sep 17 00:00:00 2001 From: jgrss Date: Fri, 3 May 2024 02:03:41 +0000 Subject: [PATCH] fix dropout --- src/cultionet/model.py | 27 ++++++---- src/cultionet/models/cultionet.py | 44 +++++------------ src/cultionet/models/lightning.py | 48 +++++++++--------- src/cultionet/models/nunet.py | 48 +++++++++++++++--- src/cultionet/nn/modules/convolution.py | 66 +++++++++++++------------ src/cultionet/scripts/args.yml | 18 ++++++- src/cultionet/scripts/cultionet.py | 2 + 7 files changed, 146 insertions(+), 107 deletions(-) diff --git a/src/cultionet/model.py b/src/cultionet/model.py index 78a88426..60acce97 100644 --- a/src/cultionet/model.py +++ b/src/cultionet/model.py @@ -13,6 +13,7 @@ ModelPruning, StochasticWeightAveraging, ) +from lightning.pytorch.tuner import Tuner from rasterio.windows import Window from scipy.stats import mode as sci_mode from torchvision import transforms @@ -243,13 +244,13 @@ def get_data_module( def setup_callbacks( ckpt_file: T.Union[str, Path], - save_top_k: T.Optional[int] = 1, - early_stopping_min_delta: T.Optional[float] = 0.01, - early_stopping_patience: T.Optional[int] = 7, - stochastic_weight_averaging: T.Optional[bool] = False, - stochastic_weight_averaging_lr: T.Optional[float] = 0.05, - stochastic_weight_averaging_start: T.Optional[float] = 0.8, - model_pruning: T.Optional[bool] = False, + save_top_k: int = 1, + early_stopping_min_delta: float = 0.01, + early_stopping_patience: int = 7, + stochastic_weight_averaging: bool = False, + stochastic_weight_averaging_lr: float = 0.05, + stochastic_weight_averaging_start: float = 0.8, + model_pruning: bool = False, ) -> T.Tuple[LearningRateMonitor, T.Sequence[T.Any]]: # Checkpoint cb_train_loss = ModelCheckpoint(monitor="loss") @@ -257,7 +258,7 @@ def setup_callbacks( cb_val_loss = ModelCheckpoint( dirpath=ckpt_file.parent, filename=ckpt_file.stem, - save_last=True, + save_last=False, save_top_k=save_top_k, mode="min", monitor="val_score", @@ -578,6 +579,7 @@ def fit( class_counts: T.Sequence[float] = None, model_type: str = ModelTypes.RESUNET3PSI, activation_type: str = "SiLU", + dropout: float = 0.0, dilations: T.Union[int, T.Sequence[int]] = None, res_block_type: str = ResBlockTypes.RES, attention_weights: str = AttentionTypes.SPATIAL_CHANNEL, @@ -608,6 +610,7 @@ def fit( skip_train: bool = False, refine_model: bool = False, finetune: bool = False, + strategy: str = "ddp", ): """Fits a model. @@ -661,6 +664,7 @@ def fit( skip_train (Optional[bool]): Whether to refine and calibrate a trained model. refine_model (Optional[bool]): Whether to skip training. finetune (bool): Not used. Placeholder for compatibility with transfer learning. + strategy (str): The model distributed strategy. """ ckpt_file = Path(ckpt_file) @@ -727,7 +731,7 @@ def fit( precision=precision, devices=devices, accelerator=device, - strategy='ddp', + strategy=strategy, log_every_n_steps=50, profiler=profiler, deterministic=False, @@ -735,7 +739,10 @@ def fit( ) if auto_lr_find: - trainer.tune(model=lit_model, datamodule=data_module) + tuner = Tuner(trainer) + lr_finder = tuner.lr_find(model=lit_model, datamodule=data_module) + opt_lr = lr_finder.suggestion() + logger.info(f"The suggested learning rate is {opt_lr}") else: if not skip_train: trainer.fit( diff --git a/src/cultionet/models/cultionet.py b/src/cultionet/models/cultionet.py index 0a57c213..9c593326 100644 --- a/src/cultionet/models/cultionet.py +++ b/src/cultionet/models/cultionet.py @@ -273,6 +273,7 @@ def __init__( num_classes: int = 2, model_type: str = ModelTypes.TOWERUNET, activation_type: str = "SiLU", + dropout: float = 0.1, dilations: T.Union[int, T.Sequence[int]] = None, res_block_type: str = ResBlockTypes.RES, attention_weights: str = "spatial_channel", @@ -290,7 +291,7 @@ def __init__( hidden_channels=self.hidden_channels, num_head=8, in_time=self.in_time, - dropout=0.1, + dropout=0.2, num_layers=2, d_model=128, time_scaler=100, @@ -305,6 +306,10 @@ def __init__( "in_time": self.in_time, "hidden_channels": self.hidden_channels, "num_classes": self.num_classes, + "attention_weights": attention_weights, + "res_block_type": res_block_type, + "dropout": dropout, + "dilations": dilations, "activation_type": activation_type, "deep_supervision": deep_supervision, "mask_activation": nn.Softmax(dim=1), @@ -315,40 +320,13 @@ def __init__( ModelTypes.RESUNET3PSI, ModelTypes.TOWERUNET, ), "The model type is not supported." + if model_type == ModelTypes.UNET3PSI: - unet3_kwargs["dilation"] = 2 if dilations is None else dilations - assert isinstance( - unet3_kwargs["dilation"], int - ), f"The dilation for {ModelTypes.UNET3PSI} must be an integer." self.mask_model = UNet3Psi(**unet3_kwargs) - elif model_type in ( - ModelTypes.RESUNET3PSI, - ModelTypes.TOWERUNET, - ): - # ResUNet3Psi - unet3_kwargs["attention_weights"] = ( - None if attention_weights == "none" else attention_weights - ) - unet3_kwargs["res_block_type"] = res_block_type - if res_block_type == ResBlockTypes.RES: - unet3_kwargs["dilations"] = ( - [2] if dilations is None else dilations - ) - assert ( - len(unet3_kwargs["dilations"]) == 1 - ), f"The dilations for {ModelTypes.RESUNET3PSI} must be a length-1 integer sequence." - elif res_block_type == ResBlockTypes.RESA: - unet3_kwargs["dilations"] = ( - [1, 2] if dilations is None else dilations - ) - assert isinstance( - unet3_kwargs["dilations"], list - ), f"The dilations for {ModelTypes.RESUNET3PSI} must be a sequence of integers." - - if model_type == ModelTypes.RESUNET3PSI: - self.mask_model = ResUNet3Psi(**unet3_kwargs) - else: - self.mask_model = TowerUNet(**unet3_kwargs) + elif model_type == ModelTypes.RESUNET3PSI: + self.mask_model = ResUNet3Psi(**unet3_kwargs) + else: + self.mask_model = TowerUNet(**unet3_kwargs) def forward(self, batch: Data) -> T.Dict[str, torch.Tensor]: # Transformer attention encoder diff --git a/src/cultionet/models/lightning.py b/src/cultionet/models/lightning.py index fb6d5613..d6bd0f77 100644 --- a/src/cultionet/models/lightning.py +++ b/src/cultionet/models/lightning.py @@ -17,7 +17,7 @@ from ..data.data import Data from ..enums import LearningRateSchedulers, ModelTypes, ResBlockTypes from ..layers.weights import init_attention_weights -from ..losses import FieldOfJunctionsLoss, TanimotoDistLoss +from ..losses import TanimotoComplementLoss, TanimotoDistLoss from .cultionet import CultioNet, GeoRefinement from .maskcrnn import BFasterRCNN from .nunet import PostUNet3Psi @@ -523,13 +523,13 @@ def get_true_labels( "true_crop_type": true_crop_type, } - def on_validation_epoch_end(self, *args, **kwargs): - """Save the model on validation end.""" - if self.logger.save_dir is not None: - model_file = Path(self.logger.save_dir) / f"{self.model_name}.pt" - if model_file.is_file(): - model_file.unlink() - torch.save(self.state_dict(), model_file) + # def on_validation_epoch_end(self, *args, **kwargs): + # """Save the model on validation end.""" + # if self.logger.save_dir is not None: + # model_file = Path(self.logger.save_dir) / f"{self.model_name}.pt" + # if model_file.is_file(): + # model_file.unlink() + # torch.save(self.state_dict(), model_file) def calc_loss( self, @@ -846,30 +846,34 @@ def configure_scorer(self): def configure_loss(self): # Distance loss - self.dist_loss = TanimotoDistLoss(one_hot_targets=False) + self.dist_loss = TanimotoComplementLoss(one_hot_targets=False) # Edge loss - self.edge_loss = TanimotoDistLoss() + self.edge_loss = TanimotoComplementLoss() # Crop mask loss - self.crop_loss = TanimotoDistLoss() + self.crop_loss = TanimotoComplementLoss() # Field of junctions loss - self.foj_loss = FieldOfJunctionsLoss() + # self.foj_loss = FieldOfJunctionsLoss() if self.deep_supervision: - self.dist_loss_deep_b = TanimotoDistLoss(one_hot_targets=False) - self.edge_loss_deep_b = TanimotoDistLoss() - self.crop_loss_deep_b = TanimotoDistLoss() - self.dist_loss_deep_c = TanimotoDistLoss(one_hot_targets=False) - self.edge_loss_deep_c = TanimotoDistLoss() - self.crop_loss_deep_c = TanimotoDistLoss() + self.dist_loss_deep_b = TanimotoComplementLoss( + one_hot_targets=False + ) + self.edge_loss_deep_b = TanimotoComplementLoss() + self.crop_loss_deep_b = TanimotoComplementLoss() + self.dist_loss_deep_c = TanimotoComplementLoss( + one_hot_targets=False + ) + self.edge_loss_deep_c = TanimotoComplementLoss() + self.crop_loss_deep_c = TanimotoComplementLoss() # Crop Temporal encoding losses - self.classes_l2_loss = TanimotoDistLoss() - self.classes_last_loss = TanimotoDistLoss() + self.classes_l2_loss = TanimotoComplementLoss() + self.classes_last_loss = TanimotoComplementLoss() if self.num_classes > 2: - self.crop_type_star_loss = TanimotoDistLoss( + self.crop_type_star_loss = TanimotoComplementLoss( scale_pos_weight=self.scale_pos_weight ) - self.crop_type_loss = TanimotoDistLoss( + self.crop_type_loss = TanimotoComplementLoss( scale_pos_weight=self.scale_pos_weight ) diff --git a/src/cultionet/models/nunet.py b/src/cultionet/models/nunet.py index 1cabf215..088ac2df 100644 --- a/src/cultionet/models/nunet.py +++ b/src/cultionet/models/nunet.py @@ -697,14 +697,38 @@ def __init__( in_channels, in_channels * 3, kernel_size=1, padding=0 ) self.final_dist = nn.Sequential( + cunn.ConvBlock2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + padding=1, + add_activation=True, + activation_type="SiLU", + ), nn.Conv2d(in_channels, 1, kernel_size=1, padding=0), nn.Sigmoid(), ) self.final_edge = nn.Sequential( + cunn.ConvBlock2d( + in_channels=in_channels + 1, + out_channels=in_channels, + kernel_size=3, + padding=1, + add_activation=True, + activation_type="SiLU", + ), nn.Conv2d(in_channels, 1, kernel_size=1, padding=0), cunn.SigmoidCrisp(), ) self.final_mask = nn.Sequential( + cunn.ConvBlock2d( + in_channels=in_channels + 2, + out_channels=in_channels, + kernel_size=3, + padding=1, + add_activation=True, + activation_type="SiLU", + ), nn.Conv2d(in_channels, num_classes, kernel_size=1, padding=0), mask_activation, ) @@ -723,14 +747,16 @@ def forward( mode="bilinear", ) - dist, edge, mask = torch.chunk(self.expand(x), 3, dim=1) + dist_connect, edge_connect, mask_connect = torch.chunk( + self.expand(x), 3, dim=1 + ) - if foj_boundaries is not None: - edge = edge * foj_boundaries + # if foj_boundaries is not None: + # edge = edge * foj_boundaries - dist = self.final_dist(dist) - edge = self.final_edge(edge) - mask = self.final_mask(mask) + dist = self.final_dist(dist_connect) + edge = self.final_edge(torch.cat((edge_connect, dist), dim=1)) + mask = self.final_mask(torch.cat((mask_connect, dist, edge), dim=1)) return { f"dist{suffix}": dist, @@ -750,6 +776,7 @@ def __init__( num_classes: int = 2, dilations: T.Sequence[int] = None, activation_type: str = "SiLU", + dropout: float = 0.0, res_block_type: str = ResBlockTypes.RES, attention_weights: str = AttentionTypes.SPATIAL_CHANNEL, mask_activation: T.Union[nn.Softmax, nn.Sigmoid] = nn.Softmax(dim=1), @@ -796,6 +823,7 @@ def __init__( self.down_b = cunn.PoolResidualConv( channels[0], channels[1], + dropout=dropout, attention_weights=attention_weights, res_block_type=res_block_type, dilations=dilations, @@ -803,6 +831,7 @@ def __init__( self.down_c = cunn.PoolResidualConv( channels[1], channels[2], + dropout=dropout, activation_type=activation_type, attention_weights=attention_weights, res_block_type=res_block_type, @@ -811,8 +840,9 @@ def __init__( self.down_d = cunn.PoolResidualConv( channels[2], channels[3], - num_blocks=1, + dropout=dropout, kernel_size=1, + num_blocks=1, activation_type=activation_type, attention_weights=attention_weights, res_block_type=res_block_type, @@ -947,8 +977,10 @@ def forward( x_c = self.down_c(x_b) x_d = self.down_d(x_c) - # Up + # Over x_du = self.up_du(x_d, shape=x_d.shape[-2:]) + + # Up x_cu = self.up_cu(x_du, shape=x_c.shape[-2:]) x_bu = self.up_bu(x_cu, shape=x_b.shape[-2:]) x_au = self.up_au(x_bu, shape=x_a.shape[-2:]) diff --git a/src/cultionet/nn/modules/convolution.py b/src/cultionet/nn/modules/convolution.py index a4ce9d3b..dfad4a8f 100644 --- a/src/cultionet/nn/modules/convolution.py +++ b/src/cultionet/nn/modules/convolution.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from einops.layers.torch import Rearrange from cultionet.enums import AttentionTypes, ResBlockTypes @@ -727,8 +728,7 @@ def __init__( self, in_channels: int, out_channels: int, - pool_size: int = 2, - dropout: T.Optional[float] = None, + dropout: float = 0.0, kernel_size: int = 3, num_blocks: int = 2, attention_weights: T.Optional[str] = None, @@ -743,41 +743,43 @@ def __init__( ResBlockTypes.RESA, ) - layers = [nn.MaxPool2d(pool_size)] - - if dropout is not None: - assert isinstance( - dropout, float - ), "The dropout arg must be a float." - layers += [nn.Dropout(dropout)] - if res_block_type == ResBlockTypes.RES: - layers += [ - ResidualConv( - in_channels, - out_channels, - kernel_size=kernel_size, - attention_weights=attention_weights, - num_blocks=num_blocks, - activation_type=activation_type, - ) - ] + self.conv = ResidualConv( + in_channels, + out_channels, + kernel_size=kernel_size, + attention_weights=attention_weights, + num_blocks=num_blocks, + activation_type=activation_type, + ) else: - layers += [ - ResidualAConv( - in_channels, - out_channels, - kernel_size=kernel_size, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - ] + self.conv = ResidualAConv( + in_channels, + out_channels, + kernel_size=kernel_size, + dilations=dilations, + attention_weights=attention_weights, + activation_type=activation_type, + ) - self.seq = nn.Sequential(*layers) + self.dropout_layer = None + if dropout > 0: + self.dropout_layer = nn.Dropout2d(p=dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) + height, width = x.shape[-2:] + + # Apply convolutions + x = self.conv(x) + + # Max pooling + x = F.adaptive_max_pool2d(x, output_size=(height // 2, width // 2)) + + # Optional dropout + if self.dropout_layer is not None: + x = self.dropout_layer(x) + + return x class SingleConv3d(nn.Module): diff --git a/src/cultionet/scripts/args.yml b/src/cultionet/scripts/args.yml index fd2b4e0b..3885618c 100644 --- a/src/cultionet/scripts/args.yml +++ b/src/cultionet/scripts/args.yml @@ -221,6 +221,13 @@ train_predict: kwargs: default: 'res' choices: ['res', 'resa'] + dropout: + short: '' + long: dropout + help: The dropout probability + kwargs: + default: 0.0 + type: '&float' dilations: short: '' long: dilations @@ -500,7 +507,7 @@ train: long: learning-rate help: The learning rate kwargs: - default: 0.01 + default: 0.02 type: '&float' lr_scheduler: short: lrs @@ -531,7 +538,7 @@ train: long: weight-decay help: Sets the weight decay for Adam optimizer\'s regularization kwargs: - default: 1e-4 + default: 2e-3 type: '&float' accumulate_grad_batches: short: agb @@ -608,6 +615,13 @@ train: help: Whether to finetune a transfer model (otherwise, do feature extraction) kwargs: action: store_true + strategy: + short: '' + long: strategy + help: The model distribution strategy + kwargs: + default: 'ddp' + choices: ['ddp', 'ddp_spawn', 'fsdp'] predict: out_path: diff --git a/src/cultionet/scripts/cultionet.py b/src/cultionet/scripts/cultionet.py index 348105a0..cedde068 100644 --- a/src/cultionet/scripts/cultionet.py +++ b/src/cultionet/scripts/cultionet.py @@ -1192,6 +1192,7 @@ def train_model(args): save_top_k=args.save_top_k, accumulate_grad_batches=args.accumulate_grad_batches, model_type=args.model_type, + dropout=args.dropout, dilations=args.dilations, res_block_type=args.res_block_type, attention_weights=args.attention_weights, @@ -1228,6 +1229,7 @@ def train_model(args): skip_train=args.skip_train, refine_model=args.refine_model, finetune=args.finetune, + strategy=args.strategy, ) # Fit the model