Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Apr 3, 2019
1 parent c36f1e9 commit d79a54a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def train(
yl = get_yolo_layers(model) # yolo layers
nf = int(model.module_defs[yl[0] - 1]['filters']) # yolo layer size (i.e. 255)

if resume: # Load previously saved PyTorch model
if resume: # Load previously saved model
if transfer: # Transfer learning
chkpt = torch.load(weights + 'yolov3.pt', map_location=device)
model.load_state_dict(
{k: v for k, v in chkpt['model'].items() if (int(k.split('.')[1]) + 1) not in yl}, strict=False)
for (name, p) in model.named_parameters():
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != nf},
strict=False)
for p in model.parameters():
p.requires_grad = True if p.shape[0] == nf else False

else: # resume from latest.pt
Expand Down

0 comments on commit d79a54a

Please sign in to comment.