Skip to content

Commit

Permalink
implement
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrss committed Apr 30, 2024
1 parent a61d6af commit ef2afc2
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 204 deletions.
8 changes: 2 additions & 6 deletions src/cultionet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,9 +581,7 @@ def fit(
dilations: T.Union[int, T.Sequence[int]] = None,
res_block_type: str = ResBlockTypes.RES,
attention_weights: str = AttentionTypes.SPATIAL_CHANNEL,
deep_sup_dist: bool = False,
deep_sup_edge: bool = False,
deep_sup_mask: bool = False,
deep_supervision: bool = False,
optimizer: str = "AdamW",
learning_rate: float = 1e-3,
lr_scheduler: str = "CosineAnnealingLR",
Expand Down Expand Up @@ -687,9 +685,7 @@ def fit(
dilations=dilations,
res_block_type=res_block_type,
attention_weights=attention_weights,
deep_sup_dist=deep_sup_dist,
deep_sup_edge=deep_sup_edge,
deep_sup_mask=deep_sup_mask,
deep_supervision=deep_supervision,
optimizer=optimizer,
learning_rate=learning_rate,
lr_scheduler=lr_scheduler,
Expand Down
12 changes: 3 additions & 9 deletions src/cultionet/models/cultionet.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,7 @@ class CultioNet(nn.Module):
dilations (int | list): The convolution dilation or dilations.
res_block_type (str): The residual convolution block type.
attention_weights (str): The attention weight type.
deep_sup_dist (bool): Whether to use deep supervision on the distance layer.
deep_sup_edge (bool): Whether to use deep supervision on the edge layer.
deep_sup_mask (bool): Whether to use deep supervision on the mask layer.
deep_supervision (bool): Whether to use deep supervision.
"""

def __init__(
Expand All @@ -278,9 +276,7 @@ def __init__(
dilations: T.Union[int, T.Sequence[int]] = None,
res_block_type: str = ResBlockTypes.RES,
attention_weights: str = "spatial_channel",
deep_sup_dist: bool = False,
deep_sup_edge: bool = False,
deep_sup_mask: bool = False,
deep_supervision: bool = False,
):
super(CultioNet, self).__init__()

Expand Down Expand Up @@ -310,9 +306,7 @@ def __init__(
"hidden_channels": self.hidden_channels,
"num_classes": self.num_classes,
"activation_type": activation_type,
# "deep_sup_dist": deep_sup_dist,
# "deep_sup_edge": deep_sup_edge,
# "deep_sup_mask": deep_sup_mask,
"deep_supervision": deep_supervision,
"mask_activation": nn.Softmax(dim=1),
}

Expand Down
129 changes: 30 additions & 99 deletions src/cultionet/models/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,17 +545,8 @@ def calc_loss(
"l2": 0.25,
"l3": 0.5,
"dist_loss": 1.0,
"dist_loss_3_1": 0.1,
"dist_loss_2_2": 0.25,
"dist_loss_1_3": 0.5,
"edge_loss": 1.0,
"edge_loss_3_1": 0.1,
"edge_loss_2_2": 0.25,
"edge_loss_1_3": 0.5,
"crop_loss": 1.0,
"crop_loss_3_1": 0.1,
"crop_loss_2_2": 0.25,
"crop_loss_1_3": 0.5,
}

true_labels_dict = self.get_true_labels(
Expand All @@ -580,77 +571,43 @@ def calc_loss(
loss = loss + classes_last_loss * weights["l3"]

# Edge losses
if self.deep_sup_dist:
dist_loss_3_1 = self.dist_loss_3_1(
predictions["dist_3_1"], batch.bdist
if self.deep_supervision:
dist_loss_deep = self.dist_loss_deep(
predictions["dist_deep"], batch.bdist
)
dist_loss_2_2 = self.dist_loss_2_2(
predictions["dist_2_2"], batch.bdist
edge_loss_deep = self.edge_loss_deep(
predictions["edge_deep"], true_labels_dict["true_edge"]
)
dist_loss_1_3 = self.dist_loss_1_3(
predictions["dist_1_3"], batch.bdist
crop_loss_deep = self.crop_loss_deep(
predictions["crop_deep"], true_labels_dict["true_crop"]
)

weights["dist_loss_deep"] = 0.5
weights["edge_loss_deep"] = 0.5
weights["crop_loss_deep"] = 0.5

# Main loss
loss = (
loss
+ dist_loss_3_1 * weights["dist_loss_3_1"]
+ dist_loss_2_2 * weights["dist_loss_2_2"]
+ dist_loss_1_3 * weights["dist_loss_1_3"]
+ dist_loss_deep * weights["dist_loss_deep"]
+ edge_loss_deep * weights["edge_loss_deep"]
+ crop_loss_deep * weights["crop_loss_deep"]
)

# Distance transform loss
dist_loss = self.dist_loss(predictions["dist"], batch.bdist)
# Main loss
loss = loss + dist_loss * weights["dist_loss"]

# Distance transform losses
if self.deep_sup_edge:
edge_loss_3_1 = self.edge_loss_3_1(
predictions["edge_3_1"], true_labels_dict["true_edge"]
)
edge_loss_2_2 = self.edge_loss_2_2(
predictions["edge_2_2"], true_labels_dict["true_edge"]
)
edge_loss_1_3 = self.edge_loss_1_3(
predictions["edge_1_3"], true_labels_dict["true_edge"]
)
# Main loss
loss = (
loss
+ edge_loss_3_1 * weights["edge_loss_3_1"]
+ edge_loss_2_2 * weights["edge_loss_2_2"]
+ edge_loss_1_3 * weights["edge_loss_1_3"]
)
# Edge loss
edge_loss = self.edge_loss(
predictions["edge"], true_labels_dict["true_edge"]
)
# Main loss
loss = loss + edge_loss * weights["edge_loss"]

# Crop mask losses
if self.deep_sup_mask:
crop_loss_3_1 = self.crop_loss_3_1(
predictions["crop_3_1"], true_labels_dict["true_crop"]
)
crop_loss_2_2 = self.crop_loss_2_2(
predictions["crop_2_2"], true_labels_dict["true_crop"]
)
crop_loss_1_3 = self.crop_loss_1_3(
predictions["crop_1_3"], true_labels_dict["true_crop"]
)
# Main loss
loss = (
loss
+ crop_loss_3_1 * weights["crop_loss_3_1"]
+ crop_loss_2_2 * weights["crop_loss_2_2"]
+ crop_loss_1_3 * weights["crop_loss_1_3"]
)

# Crop mask loss
crop_loss = self.crop_loss(
predictions["crop"], true_labels_dict["true_crop"]
)
# Main loss
loss = loss + crop_loss * weights["crop_loss"]

# if predictions["crop_type"] is not None:
Expand Down Expand Up @@ -863,31 +820,17 @@ def configure_scorer(self):
)

def configure_loss(self):
# Distance loss
self.dist_loss = TanimotoDistLoss(one_hot_targets=False)
if self.deep_sup_dist:
self.dist_loss_3_1 = TanimotoDistLoss(one_hot_targets=False)
self.dist_loss_2_2 = TanimotoDistLoss(one_hot_targets=False)
self.dist_loss_1_3 = TanimotoDistLoss(one_hot_targets=False)

# Edge losses
# Edge losse
self.edge_loss = TanimotoDistLoss()
if self.deep_sup_edge:
self.edge_loss_3_1 = TanimotoDistLoss()
self.edge_loss_2_2 = TanimotoDistLoss()
self.edge_loss_1_3 = TanimotoDistLoss()

# Crop mask losses
# Crop mask losse
self.crop_loss = TanimotoDistLoss()
if self.deep_sup_mask:
self.crop_loss_3_1 = TanimotoDistLoss(
scale_pos_weight=self.scale_pos_weight
)
self.crop_loss_2_2 = TanimotoDistLoss(
scale_pos_weight=self.scale_pos_weight
)
self.crop_loss_1_3 = TanimotoDistLoss(
scale_pos_weight=self.scale_pos_weight
)

if self.deep_supervision:
self.dist_loss_deep = TanimotoDistLoss(one_hot_targets=False)
self.edge_loss_deep = TanimotoDistLoss()
self.crop_loss_deep = TanimotoDistLoss()

# Crop Temporal encoding losses
self.classes_l2_loss = TanimotoDistLoss()
Expand Down Expand Up @@ -976,9 +919,7 @@ def __init__(
weight_decay: float = 0.01,
eps: float = 1e-4,
mask_activation: T.Callable = nn.Softmax(dim=1),
deep_sup_dist: bool = True,
deep_sup_edge: bool = True,
deep_sup_mask: bool = True,
deep_supervision: bool = True,
scale_pos_weight: bool = True,
model_name: str = "cultionet_transfer",
edge_class: T.Optional[int] = None,
Expand Down Expand Up @@ -1007,9 +948,7 @@ def __init__(
up_channels = int(init_filter * 5)
self.in_channels = in_channels
self.num_time = num_time
self.deep_sup_dist = deep_sup_dist
self.deep_sup_edge = deep_sup_edge
self.deep_sup_mask = deep_sup_mask
self.deep_supervision = deep_supervision
self.scale_pos_weight = scale_pos_weight

self.cultionet_model = CultionetLitModel.load_from_checkpoint(
Expand Down Expand Up @@ -1070,9 +1009,7 @@ def __init__(
up_channels=up_channels,
num_classes=num_classes,
mask_activation=mask_activation,
deep_sup_dist=deep_sup_dist,
deep_sup_edge=deep_sup_edge,
deep_sup_mask=deep_sup_mask,
deep_supervision=deep_supervision,
)
self.cultionet_model.mask_model.post_unet = post_unet

Expand Down Expand Up @@ -1112,9 +1049,7 @@ def __init__(
eps: float = 1e-4,
ckpt_name: str = "last",
model_name: str = "cultionet",
deep_sup_dist: bool = False,
deep_sup_edge: bool = False,
deep_sup_mask: bool = False,
deep_supervision: bool = False,
class_counts: T.Optional[torch.Tensor] = None,
edge_class: T.Optional[int] = None,
temperature_lit_model: T.Optional[GeoRefinement] = None,
Expand All @@ -1141,9 +1076,7 @@ def __init__(
self.temperature_lit_model = temperature_lit_model
self.scale_pos_weight = scale_pos_weight
self.save_batch_val_metrics = save_batch_val_metrics
self.deep_sup_dist = deep_sup_dist
self.deep_sup_edge = deep_sup_edge
self.deep_sup_mask = deep_sup_mask
self.deep_supervision = deep_supervision
self.sigmoid = torch.nn.Sigmoid()
if edge_class is not None:
self.edge_class = edge_class
Expand All @@ -1164,9 +1097,7 @@ def __init__(
dilations=dilations,
res_block_type=res_block_type,
attention_weights=attention_weights,
deep_sup_dist=deep_sup_dist,
deep_sup_edge=deep_sup_edge,
deep_sup_mask=deep_sup_mask,
deep_supervision=deep_supervision,
),
)

Expand Down
27 changes: 14 additions & 13 deletions src/cultionet/models/nunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,11 +685,10 @@ def __init__(
res_block_type: str = ResBlockTypes.RES,
attention_weights: T.Optional[str] = None,
mask_activation: T.Union[nn.Softmax, nn.Sigmoid] = nn.Softmax(dim=1),
deep_supervision: bool = False,
):
super(TowerUNet, self).__init__()

if dilations is None:
dilations = [2]
if attention_weights is None:
attention_weights = "spatial_channel"

Expand All @@ -713,7 +712,7 @@ def __init__(
self.down_a = cunn.ResidualConv(
in_channels=channels[0],
out_channels=channels[0],
dilation=dilations[0],
num_blocks=2,
activation_type=activation_type,
attention_weights=attention_weights,
)
Expand All @@ -729,22 +728,23 @@ def __init__(
self.down_b = cunn.PoolResidualConv(
channels[0],
channels[1],
dilations=dilations,
num_blocks=1,
attention_weights=attention_weights,
res_block_type=res_block_type,
)
self.down_c = cunn.PoolResidualConv(
channels[1],
channels[2],
dilations=dilations,
num_blocks=1,
activation_type=activation_type,
attention_weights=attention_weights,
res_block_type=res_block_type,
)
self.down_d = cunn.PoolResidualConv(
channels[2],
channels[3],
dilations=dilations,
num_blocks=1,
kernel_size=1,
activation_type=activation_type,
attention_weights=attention_weights,
res_block_type=res_block_type,
Expand All @@ -754,28 +754,29 @@ def __init__(
self.up_e = cunn.TowerUNetUpLayer(
in_channels=channels[3],
out_channels=up_channels,
dilations=dilations,
num_blocks=1,
kernel_size=1,
attention_weights=attention_weights,
activation_type=activation_type,
)
self.up_f = cunn.TowerUNetUpLayer(
in_channels=up_channels,
out_channels=up_channels,
dilations=dilations,
num_blocks=1,
attention_weights=attention_weights,
activation_type=activation_type,
)
self.up_g = cunn.TowerUNetUpLayer(
in_channels=up_channels,
out_channels=up_channels,
dilations=dilations,
num_blocks=1,
attention_weights=attention_weights,
activation_type=activation_type,
)
self.up_h = cunn.TowerUNetUpLayer(
in_channels=up_channels,
out_channels=up_channels,
dilations=dilations,
num_blocks=2,
attention_weights=attention_weights,
activation_type=activation_type,
)
Expand All @@ -786,7 +787,7 @@ def __init__(
backbone_down_channels=channels[3],
up_channels=up_channels,
out_channels=up_channels,
dilations=dilations,
num_blocks=1,
attention_weights=attention_weights,
activation_type=activation_type,
)
Expand All @@ -797,7 +798,7 @@ def __init__(
up_channels=up_channels,
out_channels=up_channels,
tower=True,
dilations=dilations,
num_blocks=1,
attention_weights=attention_weights,
activation_type=activation_type,
)
Expand All @@ -808,7 +809,7 @@ def __init__(
up_channels=up_channels,
out_channels=up_channels,
tower=True,
dilations=dilations,
num_blocks=2,
attention_weights=attention_weights,
activation_type=activation_type,
)
Expand Down
Loading

0 comments on commit ef2afc2

Please sign in to comment.