fix dropout
jgrss committed May 3, 2024
1 parent dd08d4c commit 024bd44
Showing 7 changed files with 146 additions and 107 deletions.
27 changes: 17 additions & 10 deletions src/cultionet/model.py
@@ -13,6 +13,7 @@
ModelPruning,
StochasticWeightAveraging,
)
from lightning.pytorch.tuner import Tuner
from rasterio.windows import Window
from scipy.stats import mode as sci_mode
from torchvision import transforms
@@ -243,21 +244,21 @@ def get_data_module(

def setup_callbacks(
ckpt_file: T.Union[str, Path],
save_top_k: T.Optional[int] = 1,
early_stopping_min_delta: T.Optional[float] = 0.01,
early_stopping_patience: T.Optional[int] = 7,
stochastic_weight_averaging: T.Optional[bool] = False,
stochastic_weight_averaging_lr: T.Optional[float] = 0.05,
stochastic_weight_averaging_start: T.Optional[float] = 0.8,
model_pruning: T.Optional[bool] = False,
save_top_k: int = 1,
early_stopping_min_delta: float = 0.01,
early_stopping_patience: int = 7,
stochastic_weight_averaging: bool = False,
stochastic_weight_averaging_lr: float = 0.05,
stochastic_weight_averaging_start: float = 0.8,
model_pruning: bool = False,
) -> T.Tuple[LearningRateMonitor, T.Sequence[T.Any]]:
# Checkpoint
cb_train_loss = ModelCheckpoint(monitor="loss")
# Validation and test loss
cb_val_loss = ModelCheckpoint(
dirpath=ckpt_file.parent,
filename=ckpt_file.stem,
save_last=True,
save_last=False,
save_top_k=save_top_k,
mode="min",
monitor="val_score",
@@ -578,6 +579,7 @@ def fit(
class_counts: T.Sequence[float] = None,
model_type: str = ModelTypes.RESUNET3PSI,
activation_type: str = "SiLU",
dropout: float = 0.0,
dilations: T.Union[int, T.Sequence[int]] = None,
res_block_type: str = ResBlockTypes.RES,
attention_weights: str = AttentionTypes.SPATIAL_CHANNEL,
@@ -608,6 +610,7 @@ def fit(
skip_train: bool = False,
refine_model: bool = False,
finetune: bool = False,
strategy: str = "ddp",
):
"""Fits a model.
@@ -661,6 +664,7 @@ def fit(
skip_train (Optional[bool]): Whether to skip training.
refine_model (Optional[bool]): Whether to refine and calibrate a trained model.
finetune (bool): Not used. Placeholder for compatibility with transfer learning.
strategy (str): The distributed training strategy passed to the Lightning Trainer.
"""
ckpt_file = Path(ckpt_file)

@@ -727,15 +731,18 @@
precision=precision,
devices=devices,
accelerator=device,
strategy='ddp',
strategy=strategy,
log_every_n_steps=50,
profiler=profiler,
deterministic=False,
benchmark=False,
)

if auto_lr_find:
trainer.tune(model=lit_model, datamodule=data_module)
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model=lit_model, datamodule=data_module)
opt_lr = lr_finder.suggestion()
logger.info(f"The suggested learning rate is {opt_lr}")
else:
if not skip_train:
trainer.fit(
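Note on the learning-rate finder change in the hunk above: Lightning >= 2.0 drops the old trainer.tune() entry point, and tuning goes through lightning.pytorch.tuner.Tuner followed by lr_find() and suggestion(), which is what the new branch does. Below is a minimal, self-contained sketch of that pattern; TinyRegressor and the random dataset are illustrative stand-ins, not cultionet classes.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.tuner import Tuner

class TinyRegressor(LightningModule):
    """Toy stand-in for the cultionet Lightning module."""

    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr  # lr_find (update_attr=True by default) writes its suggestion back here
        self.net = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

# 100 batches so the default num_training=100 LR range test can complete.
loader = DataLoader(
    TensorDataset(torch.randn(3200, 8), torch.randn(3200, 1)), batch_size=32
)
trainer = Trainer(max_epochs=1, accelerator="auto", devices=1, logger=False)

tuner = Tuner(trainer)
lr_finder = tuner.lr_find(TinyRegressor(), train_dataloaders=loader)
print("suggested lr:", lr_finder.suggestion())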
44 changes: 11 additions & 33 deletions src/cultionet/models/cultionet.py
@@ -273,6 +273,7 @@ def __init__(
num_classes: int = 2,
model_type: str = ModelTypes.TOWERUNET,
activation_type: str = "SiLU",
dropout: float = 0.1,
dilations: T.Union[int, T.Sequence[int]] = None,
res_block_type: str = ResBlockTypes.RES,
attention_weights: str = "spatial_channel",
@@ -290,7 +291,7 @@ def __init__(
hidden_channels=self.hidden_channels,
num_head=8,
in_time=self.in_time,
dropout=0.1,
dropout=0.2,
num_layers=2,
d_model=128,
time_scaler=100,
@@ -305,6 +306,10 @@ def __init__(
"in_time": self.in_time,
"hidden_channels": self.hidden_channels,
"num_classes": self.num_classes,
"attention_weights": attention_weights,
"res_block_type": res_block_type,
"dropout": dropout,
"dilations": dilations,
"activation_type": activation_type,
"deep_supervision": deep_supervision,
"mask_activation": nn.Softmax(dim=1),
@@ -315,40 +320,13 @@
ModelTypes.RESUNET3PSI,
ModelTypes.TOWERUNET,
), "The model type is not supported."

if model_type == ModelTypes.UNET3PSI:
unet3_kwargs["dilation"] = 2 if dilations is None else dilations
assert isinstance(
unet3_kwargs["dilation"], int
), f"The dilation for {ModelTypes.UNET3PSI} must be an integer."
self.mask_model = UNet3Psi(**unet3_kwargs)
elif model_type in (
ModelTypes.RESUNET3PSI,
ModelTypes.TOWERUNET,
):
# ResUNet3Psi
unet3_kwargs["attention_weights"] = (
None if attention_weights == "none" else attention_weights
)
unet3_kwargs["res_block_type"] = res_block_type
if res_block_type == ResBlockTypes.RES:
unet3_kwargs["dilations"] = (
[2] if dilations is None else dilations
)
assert (
len(unet3_kwargs["dilations"]) == 1
), f"The dilations for {ModelTypes.RESUNET3PSI} must be a length-1 integer sequence."
elif res_block_type == ResBlockTypes.RESA:
unet3_kwargs["dilations"] = (
[1, 2] if dilations is None else dilations
)
assert isinstance(
unet3_kwargs["dilations"], list
), f"The dilations for {ModelTypes.RESUNET3PSI} must be a sequence of integers."

if model_type == ModelTypes.RESUNET3PSI:
self.mask_model = ResUNet3Psi(**unet3_kwargs)
else:
self.mask_model = TowerUNet(**unet3_kwargs)
elif model_type == ModelTypes.RESUNET3PSI:
self.mask_model = ResUNet3Psi(**unet3_kwargs)
else:
self.mask_model = TowerUNet(**unet3_kwargs)

def forward(self, batch: Data) -> T.Dict[str, torch.Tensor]:
# Transformer attention encoder
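The rewritten block above collapses the per-model keyword juggling into one shared kwargs dict plus a single dispatch on the model type, with dropout carried along as an ordinary kwarg. A rough sketch of that shape, using placeholder classes rather than the real ResUNet3Psi and TowerUNet:

import typing as T
from torch import nn

class ResUNet3PsiStub(nn.Module):
    """Placeholder for cultionet's ResUNet3Psi."""
    def __init__(self, **kwargs):
        super().__init__()
        self.config = kwargs

class TowerUNetStub(nn.Module):
    """Placeholder for cultionet's TowerUNet."""
    def __init__(self, **kwargs):
        super().__init__()
        self.config = kwargs

def build_mask_model(model_type: str, dropout: float = 0.1, **shared) -> nn.Module:
    # One kwargs dict shared by every supported architecture, dropout included.
    unet3_kwargs: T.Dict[str, T.Any] = {"dropout": dropout, **shared}
    supported = {"ResUNet3Psi": ResUNet3PsiStub, "TowerUNet": TowerUNetStub}
    assert model_type in supported, "The model type is not supported."
    return supported[model_type](**unet3_kwargs)

model = build_mask_model("TowerUNet", dropout=0.2, hidden_channels=64, num_classes=2)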
48 changes: 26 additions & 22 deletions src/cultionet/models/lightning.py
@@ -17,7 +17,7 @@
from ..data.data import Data
from ..enums import LearningRateSchedulers, ModelTypes, ResBlockTypes
from ..layers.weights import init_attention_weights
from ..losses import FieldOfJunctionsLoss, TanimotoDistLoss
from ..losses import TanimotoComplementLoss, TanimotoDistLoss
from .cultionet import CultioNet, GeoRefinement
from .maskcrnn import BFasterRCNN
from .nunet import PostUNet3Psi
@@ -523,13 +523,13 @@ def get_true_labels(
"true_crop_type": true_crop_type,
}

def on_validation_epoch_end(self, *args, **kwargs):
"""Save the model on validation end."""
if self.logger.save_dir is not None:
model_file = Path(self.logger.save_dir) / f"{self.model_name}.pt"
if model_file.is_file():
model_file.unlink()
torch.save(self.state_dict(), model_file)
# def on_validation_epoch_end(self, *args, **kwargs):
# """Save the model on validation end."""
# if self.logger.save_dir is not None:
# model_file = Path(self.logger.save_dir) / f"{self.model_name}.pt"
# if model_file.is_file():
# model_file.unlink()
# torch.save(self.state_dict(), model_file)

def calc_loss(
self,
@@ -846,30 +846,34 @@ def configure_scorer(self):

def configure_loss(self):
# Distance loss
self.dist_loss = TanimotoDistLoss(one_hot_targets=False)
self.dist_loss = TanimotoComplementLoss(one_hot_targets=False)
# Edge loss
self.edge_loss = TanimotoDistLoss()
self.edge_loss = TanimotoComplementLoss()
# Crop mask loss
self.crop_loss = TanimotoDistLoss()
self.crop_loss = TanimotoComplementLoss()
# Field of junctions loss
self.foj_loss = FieldOfJunctionsLoss()
# self.foj_loss = FieldOfJunctionsLoss()

if self.deep_supervision:
self.dist_loss_deep_b = TanimotoDistLoss(one_hot_targets=False)
self.edge_loss_deep_b = TanimotoDistLoss()
self.crop_loss_deep_b = TanimotoDistLoss()
self.dist_loss_deep_c = TanimotoDistLoss(one_hot_targets=False)
self.edge_loss_deep_c = TanimotoDistLoss()
self.crop_loss_deep_c = TanimotoDistLoss()
self.dist_loss_deep_b = TanimotoComplementLoss(
one_hot_targets=False
)
self.edge_loss_deep_b = TanimotoComplementLoss()
self.crop_loss_deep_b = TanimotoComplementLoss()
self.dist_loss_deep_c = TanimotoComplementLoss(
one_hot_targets=False
)
self.edge_loss_deep_c = TanimotoComplementLoss()
self.crop_loss_deep_c = TanimotoComplementLoss()

# Crop Temporal encoding losses
self.classes_l2_loss = TanimotoDistLoss()
self.classes_last_loss = TanimotoDistLoss()
self.classes_l2_loss = TanimotoComplementLoss()
self.classes_last_loss = TanimotoComplementLoss()
if self.num_classes > 2:
self.crop_type_star_loss = TanimotoDistLoss(
self.crop_type_star_loss = TanimotoComplementLoss(
scale_pos_weight=self.scale_pos_weight
)
self.crop_type_loss = TanimotoDistLoss(
self.crop_type_loss = TanimotoComplementLoss(
scale_pos_weight=self.scale_pos_weight
)

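For context on the loss swap above: a Tanimoto-with-complement loss averages the soft Tanimoto similarity of the prediction/target pair and of their complements before inverting it. The sketch below shows only that core idea; cultionet's TanimotoComplementLoss is its own implementation, and the extra options seen in the diff (one_hot_targets, scale_pos_weight) are not reproduced here.

import torch

def tanimoto(p: torch.Tensor, t: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Soft Tanimoto similarity reduced over all non-batch dimensions.
    dims = tuple(range(1, p.ndim))
    inter = (p * t).sum(dim=dims)
    denom = (p * p).sum(dim=dims) + (t * t).sum(dim=dims) - inter
    return (inter + eps) / (denom + eps)

def tanimoto_complement_loss(p: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Average the similarity of the foreground and of its complement, then invert.
    score = 0.5 * (tanimoto(p, t) + tanimoto(1.0 - p, 1.0 - t))
    return (1.0 - score).mean()

probs = torch.rand(4, 1, 64, 64)                   # e.g. sigmoid edge probabilities
target = (torch.rand(4, 1, 64, 64) > 0.5).float()  # binary edge labels
loss = tanimoto_complement_loss(probs, target)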
48 changes: 40 additions & 8 deletions src/cultionet/models/nunet.py
@@ -697,14 +697,38 @@ def __init__(
in_channels, in_channels * 3, kernel_size=1, padding=0
)
self.final_dist = nn.Sequential(
cunn.ConvBlock2d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
padding=1,
add_activation=True,
activation_type="SiLU",
),
nn.Conv2d(in_channels, 1, kernel_size=1, padding=0),
nn.Sigmoid(),
)
self.final_edge = nn.Sequential(
cunn.ConvBlock2d(
in_channels=in_channels + 1,
out_channels=in_channels,
kernel_size=3,
padding=1,
add_activation=True,
activation_type="SiLU",
),
nn.Conv2d(in_channels, 1, kernel_size=1, padding=0),
cunn.SigmoidCrisp(),
)
self.final_mask = nn.Sequential(
cunn.ConvBlock2d(
in_channels=in_channels + 2,
out_channels=in_channels,
kernel_size=3,
padding=1,
add_activation=True,
activation_type="SiLU",
),
nn.Conv2d(in_channels, num_classes, kernel_size=1, padding=0),
mask_activation,
)
@@ -723,14 +747,16 @@ def forward(
mode="bilinear",
)

dist, edge, mask = torch.chunk(self.expand(x), 3, dim=1)
dist_connect, edge_connect, mask_connect = torch.chunk(
self.expand(x), 3, dim=1
)

if foj_boundaries is not None:
edge = edge * foj_boundaries
# if foj_boundaries is not None:
# edge = edge * foj_boundaries

dist = self.final_dist(dist)
edge = self.final_edge(edge)
mask = self.final_mask(mask)
dist = self.final_dist(dist_connect)
edge = self.final_edge(torch.cat((edge_connect, dist), dim=1))
mask = self.final_mask(torch.cat((mask_connect, dist, edge), dim=1))

return {
f"dist{suffix}": dist,
@@ -750,6 +776,7 @@ def __init__(
num_classes: int = 2,
dilations: T.Sequence[int] = None,
activation_type: str = "SiLU",
dropout: float = 0.0,
res_block_type: str = ResBlockTypes.RES,
attention_weights: str = AttentionTypes.SPATIAL_CHANNEL,
mask_activation: T.Union[nn.Softmax, nn.Sigmoid] = nn.Softmax(dim=1),
@@ -796,13 +823,15 @@ def __init__(
self.down_b = cunn.PoolResidualConv(
channels[0],
channels[1],
dropout=dropout,
attention_weights=attention_weights,
res_block_type=res_block_type,
dilations=dilations,
)
self.down_c = cunn.PoolResidualConv(
channels[1],
channels[2],
dropout=dropout,
activation_type=activation_type,
attention_weights=attention_weights,
res_block_type=res_block_type,
@@ -811,8 +840,9 @@ def __init__(
self.down_d = cunn.PoolResidualConv(
channels[2],
channels[3],
num_blocks=1,
dropout=dropout,
kernel_size=1,
num_blocks=1,
activation_type=activation_type,
attention_weights=attention_weights,
res_block_type=res_block_type,
@@ -947,8 +977,10 @@ def forward(
x_c = self.down_c(x_b)
x_d = self.down_d(x_c)

# Up
# Over
x_du = self.up_du(x_d, shape=x_d.shape[-2:])

# Up
x_cu = self.up_cu(x_du, shape=x_c.shape[-2:])
x_bu = self.up_bu(x_cu, shape=x_b.shape[-2:])
x_au = self.up_au(x_bu, shape=x_a.shape[-2:])
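The reworked final block above chains the output heads: distance is predicted first, the edge head is fed the distance map, and the mask head is fed both distance and edge. A stripped-down sketch of that cascade, with plain Conv2d and Sigmoid/Softmax standing in for cultionet's ConvBlock2d and SigmoidCrisp layers:

import torch
from torch import nn

class CascadedHeads(nn.Module):
    """Distance -> edge -> mask heads, each conditioned on the earlier outputs."""

    def __init__(self, in_channels: int, num_classes: int = 2):
        super().__init__()
        self.dist_head = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=1), nn.Sigmoid()
        )
        # +1 input channel for the distance map.
        self.edge_head = nn.Sequential(
            nn.Conv2d(in_channels + 1, 1, kernel_size=1), nn.Sigmoid()
        )
        # +2 input channels for the distance and edge maps.
        self.mask_head = nn.Sequential(
            nn.Conv2d(in_channels + 2, num_classes, kernel_size=1), nn.Softmax(dim=1)
        )

    def forward(self, dist_x, edge_x, mask_x):
        dist = self.dist_head(dist_x)
        edge = self.edge_head(torch.cat((edge_x, dist), dim=1))
        mask = self.mask_head(torch.cat((mask_x, dist, edge), dim=1))
        return dist, edge, mask

heads = CascadedHeads(in_channels=32)
x = torch.randn(2, 32, 64, 64)
dist, edge, mask = heads(x, x, x)  # the diff first chunks one tensor into three parts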