From d79a54a4beaf268d445f54939cdc35120f79f360 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 3 Apr 2019 11:07:31 +0200 Subject: [PATCH] updates --- train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 76df0a8415..f5712b21a3 100644 --- a/train.py +++ b/train.py @@ -53,12 +53,12 @@ def train( yl = get_yolo_layers(model) # yolo layers nf = int(model.module_defs[yl[0] - 1]['filters']) # yolo layer size (i.e. 255) - if resume: # Load previously saved PyTorch model + if resume: # Load previously saved model if transfer: # Transfer learning chkpt = torch.load(weights + 'yolov3.pt', map_location=device) - model.load_state_dict( - {k: v for k, v in chkpt['model'].items() if (int(k.split('.')[1]) + 1) not in yl}, strict=False) - for (name, p) in model.named_parameters(): + model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != nf}, + strict=False) + for p in model.parameters(): p.requires_grad = True if p.shape[0] == nf else False else: # resume from latest.pt