import argparse

import torch

from ofa.configuration_ofa import OFAConfig
from ofa.modeling_ofa import OFAModel
from architecture_configs import ofa_base, ofa_large, ofa_tiny


def trans_fairseq_to_huggingface(fs_model, hf_model, config):
    """Convert an original fairseq OFA checkpoint into a Hugging Face checkpoint directory."""
    model = torch.load(fs_model, map_location='cpu')
    state = model["model"]
    keys = list(state.keys())
    for k in keys:
        # Drop fairseq bookkeeping entries such as "encoder.version".
        if 'version' in k:
            del state[k]
            continue
        # Map fairseq parameter names onto the naming used by the Hugging Face OFA modules.
        new_k = k.replace('self_attn_ln', 'self_attn_mid_layer_norm').\
            replace('ffn_layernorm', 'ffn_layer_norm').\
            replace('cross_attn_ln', 'cross_attn_mid_layer_norm').\
            replace('encoder_attn', 'cross_attn').\
            replace('attn_ln', 'self_attn_mid_layer_norm')
        v = state[k]
        del state[k]
        state[new_k] = v
    model["model"] = state
    remove_ignore_keys_(state)
    # Build a fresh Hugging Face OFA model from the architecture config,
    # load the converted weights, and save it in Hugging Face format.
    ofa_config = OFAConfig(**config)
    model = OFAModel(ofa_config)
    model.load_state_dict(state)
    model.save_pretrained(hf_model)


def remove_ignore_keys_(state_dict):
    """Remove fairseq-specific keys that have no counterpart in the Hugging Face model."""
    ignore_keys = [
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
    ]
    for k in ignore_keys:
        state_dict.pop(k, None)


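# A minimal usage sketch (the paths below are illustrative, not taken from the repository):
#
#   python convert_ofa_original_ckpt_to_huggingface.py \
#       --pt_model /path/to/ofa_large.pt \
#       --hf_model_dir ./ofa-large-hf
#
# The resulting directory should then be loadable with OFAModel.from_pretrained('./ofa-large-hf').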
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Convert an original OFA checkpoint to the Hugging Face format.')
    parser.add_argument('--pt_model', type=str, default='',
                        help='path to the original fairseq checkpoint')
    parser.add_argument('--hf_model_dir', type=str, default='',
                        help='output directory for the Hugging Face checkpoint')
    args = parser.parse_args()
    # The ofa_large architecture config is hard-coded here; swap in ofa_base or ofa_tiny to match your checkpoint.
    trans_fairseq_to_huggingface(args.pt_model, args.hf_model_dir, ofa_large)