diff --git a/src/cryo_sbi/inference/command_line_tools.py b/src/cryo_sbi/inference/command_line_tools.py index 6e77bd3..8a1a491 100644 --- a/src/cryo_sbi/inference/command_line_tools.py +++ b/src/cryo_sbi/inference/command_line_tools.py @@ -37,6 +37,12 @@ def cl_npe_train_no_saving(): cl_parser.add_argument( "--saving_freq", action="store", type=int, required=False, default=20 ) + cl_parser.add_argument( + "--val_set", action="store", type=str, required=False, default=None + ) + cl_parser.add_argument( + "--val_freq", action="store", type=int, required=False, default=10 + ) cl_parser.add_argument( "--simulation_batch_size", action="store", @@ -59,4 +65,6 @@ def cl_npe_train_no_saving(): device=args.train_device, saving_frequency=args.saving_freq, simulation_batch_size=args.simulation_batch_size, + validation_set=args.val_set, + validation_frequency=args.val_freq, ) diff --git a/src/cryo_sbi/inference/train_npe_model.py b/src/cryo_sbi/inference/train_npe_model.py index a699ca5..2601c1c 100644 --- a/src/cryo_sbi/inference/train_npe_model.py +++ b/src/cryo_sbi/inference/train_npe_model.py @@ -3,13 +3,12 @@ import torch import numpy as np import torch.optim as optim -from torch.utils.data import TensorDataset -from torchvision import transforms +from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from lampe.data import JointLoader, H5Dataset from lampe.inference import NPELoss from lampe.utils import GDStep from itertools import islice +import matplotlib.pyplot as plt from cryo_sbi.inference.priors import get_image_priors, PriorLoader from cryo_sbi.inference.models.build_models import build_npe_flow_model @@ -17,7 +16,7 @@ from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator from cryo_sbi.wpa_simulator.validate_image_config import check_image_params from cryo_sbi.inference.validate_train_config import check_train_params -import cryo_sbi.utils.image_utils as img_utils +from cryo_sbi.utils.estimator_utils import sample_posterior, evaluate_log_prob def load_model( @@ -56,6 +55,8 @@ def npe_train_no_saving( device: str = "cpu", saving_frequency: int = 20, simulation_batch_size: int = 1024, + validation_set: Union[str, None] = None, + validation_frequency: int = 10 ) -> None: """ Train NPE model by simulating training data on the fly. @@ -84,9 +85,12 @@ def npe_train_no_saving( train_config = json.load(open(train_config)) check_train_params(train_config) image_config = json.load(open(image_config)) + check_image_params(image_config) assert simulation_batch_size >= train_config["BATCH_SIZE"] assert simulation_batch_size % train_config["BATCH_SIZE"] == 0 + steps_per_epoch = simulation_batch_size // train_config["BATCH_SIZE"] + epoch_repeats = 100 # number of times to simulate a batch of images per epoch if image_config["MODEL_FILE"].endswith("npy"): models = ( @@ -104,6 +108,7 @@ def npe_train_no_saving( image_prior = get_image_priors(len(models) - 1, image_config, device="cpu") index_to_cv = image_prior.priors[0].index_to_cv.to(device) + max_index = index_to_cv.max().cpu() prior_loader = PriorLoader( image_prior, batch_size=simulation_batch_size, num_workers=n_workers ) @@ -120,18 +125,35 @@ def npe_train_no_saving( ) loss = NPELoss(estimator) - optimizer = optim.AdamW( - estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=0.001 - ) + optimizer = optim.AdamW(estimator.parameters(), lr=train_config["LEARNING_RATE"], weight_decay=train_config["WEIGHT_DECAY"]) step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"]) mean_loss = [] + if validation_set is not None: + validation_set = torch.load(validation_set) + assert isinstance(validation_set, dict), "Validation set must be a dictionary" + assert "IMAGES" in validation_set, "Validation set must contain images" + assert "INDICES" in validation_set, "Validation set must contain ground truth indices" + + print("Initializing tensorboard writer") + writer = SummaryWriter() + + if validation_set is not None: + num_validation_images = validation_set["IMAGES"].shape[0] + for i in range(num_validation_images): + fig, axes = plt.subplots(1, 1, figsize=(5, 5)) + axes.imshow(validation_set["IMAGES"][i].cpu().numpy(), cmap="gray", vmax=1.5, vmin=-1.5) + axes.axis("off") + writer.add_figure(f"Validation/images", fig, global_step=i) + plt.close(fig) + writer.flush() + print("Training neural netowrk:") estimator.train() with tqdm(range(epochs), unit="epoch") as tq: for epoch in tq: losses = [] - for parameters in islice(prior_loader, 100): + for parameters in islice(prior_loader, epoch_repeats): ( indices, quaternions, @@ -171,8 +193,34 @@ def npe_train_no_saving( tq.set_postfix(loss=losses.mean().item()) mean_loss.append(losses.mean().item()) + current_step = (epoch + 1) * steps_per_epoch * epoch_repeats + + writer.add_scalar("Loss/mean", losses.mean().item(), current_step) + writer.add_scalar("Loss/std", losses.std().item(), current_step) + writer.add_scalar("Loss/last", losses[-1].item(), current_step) + if epoch % saving_frequency == 0: torch.save(estimator.state_dict(), estimator_file + f"_epoch={epoch}") + if validation_set is not None and epoch % validation_frequency == 0: + estimator.eval() + with torch.no_grad(): + val_posterior_samples = sample_posterior( + estimator, validation_set["IMAGES"], num_samples=5000, device=device, batch_size=train_config["BATCH_SIZE"] + ) + for i in range(num_validation_images): + writer.add_histogram( + f"Validation/posterior_{i}_index={validation_set['INDICES'][i].item()}", + val_posterior_samples[:, i], + global_step=current_step + ) + estimator.train() + + writer.add_hparams( + train_config, + {"hparam/best_loss": min(mean_loss), "hparam/last_loss": mean_loss[-1]} + ) + writer.flush() + writer.close() torch.save(estimator.state_dict(), estimator_file) torch.save(torch.tensor(mean_loss), loss_file) diff --git a/src/cryo_sbi/inference/validate_train_config.py b/src/cryo_sbi/inference/validate_train_config.py index 76a8c9b..61673f1 100644 --- a/src/cryo_sbi/inference/validate_train_config.py +++ b/src/cryo_sbi/inference/validate_train_config.py @@ -21,6 +21,7 @@ def check_train_params(config: dict) -> None: "BATCH_SIZE", "THETA_SHIFT", "THETA_SCALE", + "WEIGHT_DECAY" ] for key in needed_keys: