From c556b0de8618118247e3e1c4e3c64dbd91ee71e7 Mon Sep 17 00:00:00 2001 From: ZhuLvs Date: Thu, 28 Nov 2024 01:34:47 +0800 Subject: [PATCH] Update Length_main.py --- model/Length_main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/model/Length_main.py b/model/Length_main.py index e993b96..bde517b 100644 --- a/model/Length_main.py +++ b/model/Length_main.py @@ -23,14 +23,14 @@ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) best_val_loss = float('inf') - best_model_save_path = "../result_length/best_model.pth" - final_model_save_path = "../result_length/final_model.pth" + best_model_save_path = "./result_length/best_model.pth" + final_model_save_path = "./result_length/final_model.pth" num_epochs = 340 - train_dataset = MulDataset("../data/2DIR", "../data/contact", transform=transform) + train_dataset = MulDataset("./data/2DIR", "./data/contact", transform=transform) train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=14) - val_dataset = MulDataset("../data/valA", "../data/valB", transform=val_transform) + val_dataset = MulDataset("./data/valA", "./data/valB", transform=val_transform) val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=14) for epoch in range(num_epochs): @@ -47,4 +47,4 @@ print(f"Epoch [{epoch + 1}/{num_epochs}], Train Total Loss: {avg_train_loss_total:.4f}, Train Dist Loss: {avg_train_loss_dist:.4f}, Train Num Loss: {avg_train_loss_num:.4f}, Val Total Loss: {avg_val_loss_total:.4f}, Val Dist Loss: {avg_val_loss_dist:.4f}, Val Num Loss: {avg_val_loss_num:.4f}") - torch.save(model.state_dict(), final_model_save_path) \ No newline at end of file + torch.save(model.state_dict(), final_model_save_path)