diff --git a/hear21passt/models/passt.py b/hear21passt/models/passt.py index 5d5a424..50f2f7d 100644 --- a/hear21passt/models/passt.py +++ b/hear21passt/models/passt.py @@ -169,6 +169,10 @@ def _cfg(url='', **kwargs): url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.9/passt-s-kd-ap.486.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, classifier=('head.1', 'head_dist'), num_classes=527), + 'passt_l_kd_p16_128_ap47': _cfg( + url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.10/passt-l-kd-ap.47.pt', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, + classifier=('head.1', 'head_dist'), num_classes=527), 'passt_s_p16_s16_128_ap468': _cfg( url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.468.pt', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0, @@ -726,6 +730,18 @@ def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs): 'passt_s_kd_p16_128_ap486', pretrained=pretrained, distilled=True, **model_kwargs) return model +def passt_l_kd_p16_128_ap47(pretrained=False, **kwargs): + """ PaSST pre-trained on AudioSet + """ + print("\n\n Loading PaSST-L (light, reduced depth=7) pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=4708 \n\n") + model_kwargs = dict(patch_size=16, embed_dim=768, depth=7, num_heads=12, **kwargs) + if model_kwargs.get("stride") != (10, 10): + warnings.warn( + f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.") + model = _create_vision_transformer( + 'passt_l_kd_p16_128_ap47', pretrained=pretrained, distilled=True, **model_kwargs) + return model + def passt_s_p16_s16_128_ap468(pretrained=False, **kwargs): """ PaSST pre-trained on AudioSet """ @@ -866,6 +882,8 @@ def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classes=527, i model_func = passt_s_swa_p16_128_ap476 elif arch == "passt_s_kd_p16_128_ap486": # pretrained model_func = passt_s_kd_p16_128_ap486 + elif arch == "passt_l_kd_p16_128_ap47": # pretrained passt-L + model_func = passt_l_kd_p16_128_ap47 elif arch == "passt_s_p16_s16_128_ap468": if fstride!=16 or tstride!=16: raise ValueError("fstride and tstride must be 16 for arch=passt_s_p16_s16_128_ap468. " diff --git a/test.py b/test.py index 838720f..58dd552 100644 --- a/test.py +++ b/test.py @@ -31,3 +31,5 @@ model.net = get_model_passt("stfthop160", input_tdim=2000) model.net = get_model_passt("passt_20sec", input_tdim=2000) model.net = get_model_passt("passt_30sec", input_tdim=3000) + + model.net = get_model_passt("passt_l_kd_p16_128_ap47")