From a4047e9be477665fa01d45b02ce57db46b5e86a4 Mon Sep 17 00:00:00 2001
From: Isaac Seessel
Date: Sun, 17 Oct 2021 13:45:23 -0400
Subject: [PATCH] BYOL improvements

---
 .../imagenet1k/byol_transfer_in1k_linear.yaml | 25 +++++++---
 .../pretrain/byol/byol_8node_resnet.yaml      |  9 ++--
 .../img_pil_color_distortion.py               | 21 +++++---
 vissl/hooks/__init__.py                       | 12 +----
 vissl/hooks/byol_hooks.py                     | 50 ++++++++++---------
 vissl/losses/byol_loss.py                     | 23 ++++-----
 6 files changed, 75 insertions(+), 65 deletions(-)

diff --git a/configs/config/benchmark/linear_image_classification/imagenet1k/byol_transfer_in1k_linear.yaml b/configs/config/benchmark/linear_image_classification/imagenet1k/byol_transfer_in1k_linear.yaml
index 011aaf920..1bdc74c1c 100644
--- a/configs/config/benchmark/linear_image_classification/imagenet1k/byol_transfer_in1k_linear.yaml
+++ b/configs/config/benchmark/linear_image_classification/imagenet1k/byol_transfer_in1k_linear.yaml
@@ -22,6 +22,7 @@ config:
       TRANSFORMS:
         - name: RandomResizedCrop
           size: 224
+          interpolation: 3
         - name: RandomHorizontalFlip
         - name: ToTensor
         - name: Normalize
@@ -38,6 +39,7 @@ config:
       TRANSFORMS:
         - name: Resize
           size: 256
+          interpolation: 3
         - name: CenterCrop
           size: 224
         - name: ToTensor
@@ -82,7 +84,7 @@ config:
       PARAMS_FILE: "specify the model weights"
       STATE_DICT_KEY_NAME: classy_state_dict
     SYNC_BN_CONFIG:
-      CONVERT_BN_TO_SYNC_BN: True
+      CONVERT_BN_TO_SYNC_BN: False
       SYNC_BN_TYPE: apex
       GROUP_SIZE: 8
   LOSS:
@@ -93,22 +95,29 @@ config:
       name: sgd
      momentum: 0.9
      num_epochs: 80
+      weight_decay: 0
      nesterov: True
      regularize_bn: False
      regularize_bias: True
      param_schedulers:
        lr:
          auto_lr_scaling:
-            auto_scale: true
-            base_value: 0.4
+            # if set to True, learning rate will be scaled.
+            auto_scale: True
+            # base learning rate value that will be scaled.
+            base_value: 0.2
+            # batch size for which the base learning rate is specified. The current batch size
+            # is used to determine how to scale the base learning rate value.
+            # scaled_lr = ((batchsize_per_gpu * world_size) * base_value) / base_lr_batch_size
            base_lr_batch_size: 256
-          name: multistep
-          values: [0.4, 0.3, 0.2, 0.1, 0.05]
-          milestones: [16, 32, 48, 64]
-          update_interval: epoch
+            # scaling_type can be set to "sqrt" to reduce the impact of scaling on the base value
+            scaling_type: "linear"
+          name: constant
+          update_interval: "epoch"
+          value: 0.2
   DISTRIBUTED:
     BACKEND: nccl
-    NUM_NODES: 8
+    NUM_NODES: 4
     NUM_PROC_PER_NODE: 8
     INIT_METHOD: tcp
     RUN_ID: auto
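A quick sanity check on the scaling comments above: the rule is easy to evaluate standalone. In the sketch below, the per-GPU batch size of 32 is an illustrative assumption, not a value pinned by this config.

    # Linear LR scaling, as documented in the config comments above.
    batchsize_per_gpu = 32        # assumed for illustration; not set in this file
    world_size = 4 * 8            # NUM_NODES * NUM_PROC_PER_NODE
    base_value = 0.2
    base_lr_batch_size = 256

    scaled_lr = (batchsize_per_gpu * world_size) * base_value / base_lr_batch_size
    print(scaled_lr)              # 0.8 for a global batch size of 1024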
diff --git a/configs/config/pretrain/byol/byol_8node_resnet.yaml b/configs/config/pretrain/byol/byol_8node_resnet.yaml
index e23822b91..ab4665b79 100644
--- a/configs/config/pretrain/byol/byol_8node_resnet.yaml
+++ b/configs/config/pretrain/byol/byol_8node_resnet.yaml
@@ -67,7 +67,7 @@ config:
       RESNETS:
         DEPTH: 50
         ZERO_INIT_RESIDUAL: True
-      HEAD: 
+      HEAD:
         PARAMS: [
           ["mlp", {"dims": [2048, 4096, 256], "use_relu": True, "use_bn": True}],
           ["mlp", {"dims": [256, 4096, 256], "use_relu": True, "use_bn": True}]
@@ -82,15 +82,16 @@ config:
     byol_loss:
       embedding_dim: 256
       momentum: 0.99
-  OPTIMIZER: # from official BYOL implementation, deepmind-research/byol/configs/byol.py
+  OPTIMIZER:
     name: lars
-    trust_coefficient: 0.001
+    eta: 0.001
     weight_decay: 1.0e-6
     momentum: 0.9
     nesterov: False
     num_epochs: 300
     regularize_bn: False
-    regularize_bias: True
+    regularize_bias: False
+    exclude_bias_and_norm: True
     param_schedulers:
       lr:
         auto_lr_scaling:
diff --git a/vissl/data/ssl_transforms/img_pil_color_distortion.py b/vissl/data/ssl_transforms/img_pil_color_distortion.py
index e3f79e4ca..df8f953b6 100644
--- a/vissl/data/ssl_transforms/img_pil_color_distortion.py
+++ b/vissl/data/ssl_transforms/img_pil_color_distortion.py
@@ -21,8 +21,16 @@ class ImgPilColorDistortion(ClassyTransform):
     randomly convert the image to grayscale.
     """
 
-    def __init__(self, strength, brightness=0.8, contrast=0.8, saturation=0.8,
-                 hue=0.2, color_jitter_probability=0.8, grayscale_probability=0.2):
+    def __init__(
+        self,
+        strength,
+        brightness=0.8,
+        contrast=0.8,
+        saturation=0.8,
+        hue=0.2,
+        color_jitter_probability=0.8,
+        grayscale_probability=0.2,
+    ):
         """
         Args:
             strength (float): A number used to quantify the strength of the
@@ -41,22 +49,23 @@ def __init__(self, strength, brightness=0.8, contrast=0.8, saturation=0.8,
             grayscale_probability (float): A floating point number used to
                 quantify to apply randomly convert image to grayscale with the
                 assigned probability. Default value is 0.2.
-        This function follows the Pytorch documentation: https://pytorch.org/vision/stable/transforms.html
         """
         self.strength = strength
         self.brightness = brightness
         self.contrast = contrast
         self.saturation = saturation
         self.hue = hue
-        self.color_jitter_probability=color_jitter_probability
-        self.grayscale_probability=grayscale_probability
+        self.color_jitter_probability = color_jitter_probability
+        self.grayscale_probability = grayscale_probability
         self.color_jitter = pth_transforms.ColorJitter(
             self.brightness * self.strength,
             self.contrast * self.strength,
             self.saturation * self.strength,
             self.hue * self.strength,
         )
-        self.rnd_color_jitter = pth_transforms.RandomApply([self.color_jitter], p=self.color_jitter_probability)
+        self.rnd_color_jitter = pth_transforms.RandomApply(
+            [self.color_jitter], p=self.color_jitter_probability
+        )
         self.rnd_gray = pth_transforms.RandomGrayscale(p=self.grayscale_probability)
         self.transforms = pth_transforms.Compose([self.rnd_color_jitter, self.rnd_gray])
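For reference, with the defaults shown above, ImgPilColorDistortion(strength=1.0) reduces to the following plain-torchvision pipeline (a standalone sketch; pth_transforms is the torchvision.transforms alias used in this module):

    import torchvision.transforms as T

    strength = 1.0
    color_jitter = T.ColorJitter(
        brightness=0.8 * strength,
        contrast=0.8 * strength,
        saturation=0.8 * strength,
        hue=0.2 * strength,
    )
    # Jitter applied with p=0.8, then random grayscale with p=0.2,
    # matching the constructor defaults above.
    transform = T.Compose(
        [T.RandomApply([color_jitter], p=0.8), T.RandomGrayscale(p=0.2)]
    )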
diff --git a/vissl/hooks/__init__.py b/vissl/hooks/__init__.py
index 41e3276c6..9f19cbc8a 100644
--- a/vissl/hooks/__init__.py
+++ b/vissl/hooks/__init__.py
@@ -8,6 +8,7 @@
 from classy_vision.hooks.classy_hook import ClassyHook
 
 from vissl.config import AttrDict
+from vissl.hooks.byol_hooks import BYOLHook  # noqa
 from vissl.hooks.deepclusterv2_hooks import ClusterMemoryHook, InitMemoryHook  # noqa
 from vissl.hooks.dino_hooks import DINOHook
 from vissl.hooks.grad_clip_hooks import GradClipHook  # noqa
@@ -21,15 +22,12 @@
 )
 from vissl.hooks.moco_hooks import MoCoHook  # noqa
 from vissl.hooks.profiling_hook import ProfilingHook
-from vissl.hooks.byol_hooks import BYOLHook # noqa
-
 from vissl.hooks.state_update_hooks import (  # noqa
     CheckNanLossHook,
     FreezeParametersHook,
     SetDataSamplerEpochHook,
     SSLModelComplexityHook,
 )
-from vissl.hooks.byol_hooks import BYOLHook # noqa
 from vissl.hooks.swav_hooks import NormalizePrototypesHook  # noqa
 from vissl.hooks.swav_hooks import SwAVUpdateQueueScoresHook  # noqa
 from vissl.hooks.swav_momentum_hooks import (
@@ -149,14 +147,6 @@ def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
             )
         ]
     )
-    if cfg.LOSS.name == "byol_loss":
-        hooks.extend(
-            [
-                BYOLHook(
-                    cfg.LOSS["byol_loss"]["momentum"],
-                )
-            ]
-        )
     if cfg.HOOKS.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY:
         hooks.extend([SSLModelComplexityHook()])
     if cfg.HOOKS.LOG_GPU_STATS:
diff --git a/vissl/hooks/byol_hooks.py b/vissl/hooks/byol_hooks.py
index 12c266184..c1536f928 100644
--- a/vissl/hooks/byol_hooks.py
+++ b/vissl/hooks/byol_hooks.py
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
-import math
 import logging
+import math
 
 import torch
 from classy_vision import tasks
@@ -8,14 +8,15 @@
 from vissl.models import build_model
 from vissl.utils.env import get_machine_local_and_dist_rank
 
+
 class BYOLHook(ClassyHook):
     """
-    BYOL - Bootstrap your own latent: (https://arxiv.org/abs/2006.07733)
-    is based on Contrastive learning. This hook
-    creates a target network with the same architecture
-    as the main online network, but without the projection head.
-    The online network does not participate in backpropogation,
-    but instead is an exponential moving average of the online network.
+    BYOL - Bootstrap your own latent (https://arxiv.org/abs/2006.07733)
+    is based on contrastive learning. This hook
+    creates a target network with the same architecture
+    as the main online network, but without the prediction head.
+    The target network does not participate in backpropagation;
+    instead, its weights are an exponential moving average of the online network's.
     """
 
     on_start = ClassyHook._noop
@@ -28,7 +29,7 @@ class BYOLHook(ClassyHook):
     on_update = ClassyHook._noop
 
     @staticmethod
-    def cosine_decay(training_iter, max_iters, initial_value) -> float: 
+    def cosine_decay(training_iter, max_iters, initial_value) -> float:
         """
         For a given starting value, this function anneals the learning
         rate.
@@ -42,8 +43,8 @@ def target_ema(training_iter, base_ema, max_iters) -> float:
         """
         Updates Exponential Moving average of the Target Network.
         """
-        decay = BYOLHook.cosine_decay(training_iter, max_iters, 1.)
-        return 1. - (1. - base_ema) * decay
+        decay = BYOLHook.cosine_decay(training_iter, max_iters, 1.0)
+        return 1.0 - (1.0 - base_ema) * decay
 
     def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
         """
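Context for the schedule above: the EMA coefficient starts at base_ema and anneals toward 1.0 over training. A standalone sketch, assuming the standard BYOL half-cosine decay for cosine_decay (its body is outside the hunk):

    import math

    def cosine_decay(training_iter, max_iters, initial_value):
        # Assumed implementation: half-cosine decay of initial_value to 0.
        training_iter = min(training_iter, max_iters)
        return initial_value * (1 + math.cos(math.pi * training_iter / max_iters)) / 2

    def target_ema(training_iter, base_ema, max_iters):
        # Mirrors BYOLHook.target_ema above.
        decay = cosine_decay(training_iter, max_iters, 1.0)
        return 1.0 - (1.0 - base_ema) * decay

    print(target_ema(0, 0.99, 1000))     # 0.99 at the first iteration
    print(target_ema(500, 0.99, 1000))   # 0.995 halfway through
    print(target_ema(1000, 0.99, 1000))  # 1.0 at the end of training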
@@ -53,19 +54,19 @@ def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
         """
         # Create the encoder, which will slowly track the model
         logging.info(
-            "BYOL: Building BYOL target network - rank %s %s", *get_machine_local_and_dist_rank()
+            "BYOL: Building BYOL target network - rank %s %s",
+            *get_machine_local_and_dist_rank(),
         )
 
-        # Target model has the same architecture, but without the projector head.
-        target_model_config = task.config['MODEL']
-        target_model_config['HEAD']['PARAMS'] = target_model_config['HEAD']['PARAMS'][0:1]
+        # Target model has the same architecture, *without* the predictor head.
+        target_model_config = task.config["MODEL"]
+        target_model_config["HEAD"]["PARAMS"] = target_model_config["HEAD"]["PARAMS"][
+            0:1
+        ]
         task.loss.target_network = build_model(
             target_model_config, task.config["OPTIMIZER"]
         )
 
-        # TESTED: Target Network and Online network are properly created.
-        # TODO: Check SyncBatchNorm settings (low prior)
-
         task.loss.target_network.to(task.device)
 
         # Restore an hypothetical checkpoint, else copy the model parameters from the
         # online network.
@@ -73,7 +74,9 @@ def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
         if task.loss.checkpoint is not None:
             task.loss.load_state_dict(task.loss.checkpoint)
         else:
-            logging.info("BYOL: Copying and freezing model parameters from online to target network")
+            logging.info(
+                "BYOL: Copying and freezing model parameters from online to target network"
+            )
             for param_q, param_k in zip(
                 task.base_model.parameters(), task.loss.target_network.parameters()
             ):
@@ -92,7 +95,9 @@ def _update_momentum_coefficient(self, task: tasks.ClassyTask) -> None:
             self.total_iters = task.max_iteration
             logging.info(f"{self.total_iters} total iters")
         training_iteration = task.iteration
-        self.momentum = self.target_ema(training_iteration, self.base_momentum, self.total_iters)
+        self.momentum = self.target_ema(
+            training_iteration, self.base_momentum, self.total_iters
+        )
 
     @torch.no_grad()
     def _update_target_network(self, task: tasks.ClassyTask) -> None:
@@ -106,10 +111,10 @@ def _update_target_network(self, task: tasks.ClassyTask) -> None:
         for online_params, target_params in zip(
             task.base_model.parameters(), task.loss.target_network.parameters()
         ):
             target_params.data = (
-                target_params.data * self.momentum + online_params.data * (1. - self.momentum)
+                target_params.data * self.momentum
+                + online_params.data * (1.0 - self.momentum)
             )
-
     @torch.no_grad()
     def on_forward(self, task: tasks.ClassyTask) -> None:
         """
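The EMA update in the hunk above, restated as a standalone sketch (ema_update is an illustrative helper, not VISSL API):

    import torch

    @torch.no_grad()
    def ema_update(online_params, target_params, momentum):
        # target <- momentum * target + (1 - momentum) * online, per the hunk above.
        for online_p, target_p in zip(online_params, target_params):
            target_p.data = target_p.data * momentum + online_p.data * (1.0 - momentum)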
@@ -127,9 +132,8 @@ def on_forward(self, task: tasks.ClassyTask) -> None:
             self._build_byol_target_network(task)
         else:
             self._update_target_network(task)
 
-        # Compute target network embeddings
-        batch = task.last_batch.sample['input']
+        batch = task.last_batch.sample["input"]
         target_embs = task.loss.target_network(batch)[0]
 
         # Save target embeddings to use them in the loss
diff --git a/vissl/losses/byol_loss.py b/vissl/losses/byol_loss.py
index b0fbdfcca..03581356b 100644
--- a/vissl/losses/byol_loss.py
+++ b/vissl/losses/byol_loss.py
@@ -7,9 +7,9 @@
 import torch.nn.functional as F
 from classy_vision.losses import ClassyLoss, register_loss
 
-_BYOLLossConfig = namedtuple(
-    "_BYOLLossConfig", ["embedding_dim", "momentum"]
-)
+
+_BYOLLossConfig = namedtuple("_BYOLLossConfig", ["embedding_dim", "momentum"])
+
 
 def regression_loss(x, y):
     """
@@ -19,17 +19,16 @@ def regression_loss(x, y):
     Cosine similarity. This implementation uses Cosine similarity.
     """
     normed_x, normed_y = F.normalize(x, dim=1), F.normalize(y, dim=1)
-    return torch.sum((normed_x - normed_y).pow(2), dim=1)
+    # Squared Euclidean distance of the l2-normalized vectors (= 2 - 2 * cosine).
+    return 2 - 2 * (normed_x * normed_y).sum(dim=1)
 
 
 class BYOLLossConfig(_BYOLLossConfig):
-    """ Settings for the BYOL loss"""
+    """Settings for the BYOL loss"""
 
     @staticmethod
     def defaults() -> "BYOLLossConfig":
-        return BYOLLossConfig(
-            embedding_dim=256, momentum=0.999
-        )
+        return BYOLLossConfig(embedding_dim=256, momentum=0.999)
 
 
 @register_loss("byol_loss")
@@ -68,7 +67,9 @@ def from_config(cls, config: BYOLLossConfig) -> "BYOLLoss":
         """
         return cls(config)
 
-    def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+    def forward(
+        self, online_network_prediction: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
         """
         In this function, the Online Network receives the tensor as input after projection
         and they make predictions on the output of the target network’s projection,
@@ -79,7 +80,6 @@ def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> t
         compute the cross entropy loss for this batch.
 
         Args:
-            query: output of the encoder given the current batch
             online_network_prediction: online model output. this is a prediction of the
             target network output.
 
@@ -91,8 +91,6 @@ def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> t
         online_view1, online_view2 = torch.chunk(online_network_prediction, 2, 0)
         target_view1, target_view2 = torch.chunk(self.target_embs.detach(), 2, 0)
 
-        # TESTED: Views are received correctly.
-
         # Compute losses
         loss1 = regression_loss(online_view1, target_view2)
         loss2 = regression_loss(online_view2, target_view1)
@@ -111,7 +109,6 @@ def load_state_dict(self, state_dict, *args, **kwargs) -> None:
         Args:
             state_dict (serialized via torch.save)
         """
-        # If the encoder has been allocated, use the normal pytorch restoration
         if self.target_network is None:
             self.checkpoint = state_dict
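The old and new regression_loss bodies are equivalent on l2-normalized inputs, since ||x - y||^2 = 2 - 2<x, y> for unit vectors. A standalone check:

    import torch
    import torch.nn.functional as F

    x, y = torch.randn(8, 256), torch.randn(8, 256)
    normed_x, normed_y = F.normalize(x, dim=1), F.normalize(y, dim=1)

    old = torch.sum((normed_x - normed_y).pow(2), dim=1)  # removed body
    new = 2 - 2 * (normed_x * normed_y).sum(dim=1)        # body added in this patch
    assert torch.allclose(old, new, atol=1e-6)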