You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
{{ message }}
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.
The weights downloaded from the NPID and MoCoV2 URLs appear to be identical. Perhaps this is a copying error?
Run the code below to demonstrate that the two sets of weights are equivalent:
def get_vissl_model(weights_url):
    """Build a torchvision ResNet-50 trunk from a VISSL checkpoint URL.

    The checkpoint is fetched with ``torch.hub.load_state_dict_from_url`` and
    mapped to CPU. The trunk weights are taken from whichever layout the
    checkpoint uses. The ``_feature_blocks.`` and ``module.`` key prefixes are
    stripped, and any key containing ``fc``, ``projection`` or ``prototypes``
    is dropped. The model's ``fc`` layer is replaced with an identity, so it
    returns the features that would have been fed to the classifier.

    Args:
        weights_url: URL of a VISSL ``.torch`` checkpoint.

    Returns:
        A ``torchvision.models.resnet50`` with the converted weights loaded
        (strict key matching) and ``fc`` set to ``nn.Identity``.

    Raises:
        RuntimeError: from ``load_state_dict`` if the converted keys do not
            exactly match the model's state dict.
    """
    # ``torch`` must be bound explicitly: importing a function from
    # ``torch.hub`` does not bind the name ``torch`` used for ``torch.device``.
    import torch
    import torch.nn as nn
    from torch.hub import load_state_dict_from_url
    from torchvision.models import resnet50

    def replace_module_prefix(state_dict, prefix, replace_with=""):
        # Rewrite only a leading ``prefix``; all other keys pass through unchanged.
        return {
            (key.replace(prefix, replace_with, 1) if key.startswith(prefix) else key): val
            for key, val in state_dict.items()
        }

    def convert_model_weights(checkpoint):
        # Locate the trunk state dict: nested under "classy_state_dict",
        # stored under "model_state_dict", or the checkpoint itself.
        if "classy_state_dict" in checkpoint:
            trunk = checkpoint["classy_state_dict"]["base_model"]["model"]["trunk"]
        elif "model_state_dict" in checkpoint:
            trunk = checkpoint["model_state_dict"]
        else:
            trunk = checkpoint
        return replace_module_prefix(trunk, "_feature_blocks.")

    weights = load_state_dict_from_url(weights_url, map_location=torch.device("cpu"))
    converted_weights = convert_model_weights(weights)

    # Drop head parameters; the model's own ``fc`` is replaced by an identity below.
    excess_weights = ("fc", "projection", "prototypes")
    converted_weights = {
        key: value
        for key, value in converted_weights.items()
        if not any(x in key for x in excess_weights)
    }

    # Strip a leading "module." from every key that has one, instead of
    # deciding from the first key alone (which also failed on an empty dict).
    converted_weights = replace_module_prefix(converted_weights, "module.")

    model = resnet50()
    model.fc = nn.Identity()
    model.load_state_dict(converted_weights)
    return model
# Compare VISSL ResNet-50 checkpoints. NPID and MoCoV2 are reported to load
# identical weights; Barlow Twins is included as a control that should differ.
if __name__ == "__main__":
    import torch

    weights_urls = {
        "NPID": "https://dl.fbaipublicfiles.com/vissl/model_zoo/npid_1node_200ep_4kneg_npid_8gpu_resnet_23_07_20.9eb36512/model_final_checkpoint_phase199.torch",
        "MoCoV2": "https://dl.fbaipublicfiles.com/vissl/model_zoo/moco_v2_1node_lr.03_step_b32_zero_init/model_final_checkpoint_phase199.torch",
        "BarlowTwins": "https://dl.fbaipublicfiles.com/vissl/model_zoo/barlow_twins/barlow_twins_32gpus_4node_imagenet1k_1000ep_resnet50.torch",
    }

    models = {}
    for name, weights_url in weights_urls.items():
        model = get_vissl_model(weights_url)
        models[name] = model
        # ``parameters()`` is a generator, so take its first element with
        # ``next`` before slicing. The slice assumes a 4-D tensor (presumably
        # the first conv weight) and serves as a quick visual fingerprint.
        print(f"### {name}")
        print(next(model.parameters())[1:10, 1, 1, 1])

    def same_weights(a, b):
        """Return True if both models have the same state-dict keys and equal tensors."""
        sd_a, sd_b = a.state_dict(), b.state_dict()
        return sd_a.keys() == sd_b.keys() and all(
            torch.equal(sd_a[k], sd_b[k]) for k in sd_a
        )

    print("NPID == MoCoV2:", same_weights(models["NPID"], models["MoCoV2"]))
    print("NPID == BarlowTwins:", same_weights(models["NPID"], models["BarlowTwins"]))
The text was updated successfully, but these errors were encountered:
The weights downloaded from the NPID and MoCoV2 URLs appear to be identical. Perhaps this is a copying error?
Run the code below to demonstrate that the two sets of weights are equivalent:
The text was updated successfully, but these errors were encountered: