Skip to content

Commit

Permalink
add passt-L pre-trained weights
Browse files Browse the repository at this point in the history
  • Loading branch information
kkoutini committed Jun 6, 2023
1 parent a271296 commit d9e1ce4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
18 changes: 18 additions & 0 deletions hear21passt/models/passt.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def _cfg(url='', **kwargs):
url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.9/passt-s-kd-ap.486.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_l_kd_p16_128_ap47': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.10/passt-l-kd-ap.47.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_p16_s16_128_ap468': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.468.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
Expand Down Expand Up @@ -726,6 +730,18 @@ def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs):
'passt_s_kd_p16_128_ap486', pretrained=pretrained, distilled=True, **model_kwargs)
return model

def passt_l_kd_p16_128_ap47(pretrained=False, **kwargs):
    """Create the PaSST-L (depth 7) model variant registered as 'passt_l_kd_p16_128_ap47'.

    Prints a loading banner, merges the fixed architecture settings
    (patch_size=16, embed_dim=768, depth=7, num_heads=12) with ``kwargs``,
    and delegates construction to ``_create_vision_transformer`` with
    ``distilled=True``.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Additional model options. Supplying one of the fixed
            architecture keys again raises ``TypeError`` (duplicate keyword).

    Returns:
        Whatever ``_create_vision_transformer`` returns for this variant.

    Warns:
        UserWarning: if ``stride`` in the merged options is not ``(10, 10)``;
            this also fires when ``stride`` is not supplied at all.
    """
    # The banner is emitted before the options are merged.
    print("\n\n Loading PaSST-L (light, reduced depth=7) pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=4708 \n\n")

    model_kwargs = dict(patch_size=16, embed_dim=768, depth=7, num_heads=12, **kwargs)

    # Only warn (never fail) when the caller deviates from the stride used in pre-training.
    stride_matches_pretraining = model_kwargs.get("stride") == (10, 10)
    if not stride_matches_pretraining:
        warnings.warn(
            f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")

    return _create_vision_transformer(
        'passt_l_kd_p16_128_ap47',
        pretrained=pretrained,
        distilled=True,
        **model_kwargs,
    )

def passt_s_p16_s16_128_ap468(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
Expand Down Expand Up @@ -866,6 +882,8 @@ def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classes=527, i
model_func = passt_s_swa_p16_128_ap476
elif arch == "passt_s_kd_p16_128_ap486": # pretrained
model_func = passt_s_kd_p16_128_ap486
elif arch == "passt_l_kd_p16_128_ap47": # pretrained passt-L
model_func = passt_l_kd_p16_128_ap47
elif arch == "passt_s_p16_s16_128_ap468":
if fstride!=16 or tstride!=16:
raise ValueError("fstride and tstride must be 16 for arch=passt_s_p16_s16_128_ap468. "
Expand Down
2 changes: 2 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@
# Build several architectures one after another. Each call rebinds model.net,
# so only the result of the final assignment is left once these lines have run.
model.net = get_model_passt("stfthop160", input_tdim=2000)
model.net = get_model_passt("passt_20sec", input_tdim=2000)
model.net = get_model_passt("passt_30sec", input_tdim=3000)

# PaSST-L architecture name; unlike the calls above, no input_tdim is passed here.
model.net = get_model_passt("passt_l_kd_p16_128_ap47")

0 comments on commit d9e1ce4

Please sign in to comment.