Skip to content

Commit

Permalink
Update Length_main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhuLvs authored Nov 27, 2024
1 parent 5b8c56e commit c556b0d
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions model/Length_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

best_val_loss = float('inf')
best_model_save_path = "../result_length/best_model.pth"
final_model_save_path = "../result_length/final_model.pth"
best_model_save_path = "./result_length/best_model.pth"
final_model_save_path = "./result_length/final_model.pth"
num_epochs = 340


train_dataset = MulDataset("../data/2DIR", "../data/contact", transform=transform)
train_dataset = MulDataset("./data/2DIR", "./data/contact", transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=14)
val_dataset = MulDataset("../data/valA", "../data/valB", transform=val_transform)
val_dataset = MulDataset("./data/valA", "./data/valB", transform=val_transform)
val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=14)

for epoch in range(num_epochs):
Expand All @@ -47,4 +47,4 @@

print(f"Epoch [{epoch + 1}/{num_epochs}], Train Total Loss: {avg_train_loss_total:.4f}, Train Dist Loss: {avg_train_loss_dist:.4f}, Train Num Loss: {avg_train_loss_num:.4f}, Val Total Loss: {avg_val_loss_total:.4f}, Val Dist Loss: {avg_val_loss_dist:.4f}, Val Num Loss: {avg_val_loss_num:.4f}")

torch.save(model.state_dict(), final_model_save_path)
torch.save(model.state_dict(), final_model_save_path)

0 comments on commit c556b0d

Please sign in to comment.