Skip to content

Commit

Permalink
Update pre_length.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhuLvs authored Nov 27, 2024
1 parent 886255c commit 5ecbeb7
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions model/pre_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def remove_module_prefix(state_dict):
if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Multitasking().to(device)
state_dict = (torch.load("../result_length/best_model.pth", map_location=device))
state_dict = (torch.load("./result_length/best_model.pth", map_location=device))
new_state_dict = remove_module_prefix(state_dict)
model.load_state_dict(new_state_dict)

Expand All @@ -47,6 +47,6 @@ def remove_module_prefix(state_dict):
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img_folder = "../input/testA"
output_folder = "../input/out_length"
img_folder = "./input/testA"
output_folder = "./input/out_length"
predict_and_save(model, img_folder, val_transform, device, output_folder)

0 comments on commit 5ecbeb7

Please sign in to comment.