diff --git a/src/cryo_sbi/inference/train_npe_model.py b/src/cryo_sbi/inference/train_npe_model.py index 732a08b..612f990 100644 --- a/src/cryo_sbi/inference/train_npe_model.py +++ b/src/cryo_sbi/inference/train_npe_model.py @@ -17,6 +17,7 @@ from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator from cryo_sbi.wpa_simulator.validate_image_config import check_image_params from cryo_sbi.inference.validate_train_config import check_train_params +import cryo_sbi.utils.image_utils as img_utils def load_model( @@ -121,7 +122,7 @@ def npe_train_no_saving( ) step = GDStep(optimizer, clip=train_config["CLIP_GRADIENT"]) mean_loss = [] - + print("Training neural netowrk:") estimator.train() with tqdm(range(epochs), unit="epoch") as tq: