diff --git a/model/pre_length.py b/model/pre_length.py index 0ca8949..6795edd 100644 --- a/model/pre_length.py +++ b/model/pre_length.py @@ -38,7 +38,7 @@ def remove_module_prefix(state_dict): if __name__ == '__main__': device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = Multitasking().to(device) - state_dict = (torch.load("../result_length/best_model.pth", map_location=device)) + state_dict = (torch.load("./result_length/best_model.pth", map_location=device)) new_state_dict = remove_module_prefix(state_dict) model.load_state_dict(new_state_dict) @@ -47,6 +47,6 @@ def remove_module_prefix(state_dict): transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) - img_folder = "../input/testA" - output_folder = "../input/out_length" + img_folder = "./input/testA" + output_folder = "./input/out_length" predict_and_save(model, img_folder, val_transform, device, output_folder)