Skip to content

Commit

Permalink
Update pre.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhuLvs authored Nov 27, 2024
1 parent 7ae0985 commit 6160375
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions model/pre.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def make_symmetric_and_zero_diag_numpy(matrix):
matrix_symmetric = (matrix_abs + matrix_abs.T) / 2
return matrix_symmetric


def predict_and_save(model, img_folder, transform, device, output_folder):
model.eval()
with torch.no_grad():
Expand All @@ -34,6 +33,14 @@ def remove_module_prefix(state_dict):
return {k.replace('module.', ''): v for k, v in state_dict.items()}

if __name__ == '__main__':
# Set up directories
val_folder = "./data/valA"
output_folder = "./data/out"

# Create output folder if it doesn't exist
if not os.path.exists(output_folder):
os.makedirs(output_folder)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CustomDeepLabV3().to(device)

Expand All @@ -47,7 +54,5 @@ def remove_module_prefix(state_dict):
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img_folder = "./input/valA"
output_folder = "./input/out"
predict_and_save(model, img_folder, val_transform, device, output_folder)

# Call predict and save with the updated folder paths
predict_and_save(model, val_folder, val_transform, device, output_folder)

0 comments on commit 6160375

Please sign in to comment.