diff --git a/q_align/train/train_mem.py b/q_align/train/train_mem.py index 638221c..a8d5c1a 100644 --- a/q_align/train/train_mem.py +++ b/q_align/train/train_mem.py @@ -823,8 +823,8 @@ def make_inputs_require_grad(module, input, output): if training_args.freeze_vision_model: for p in model.get_model().vision_model.parameters(): p.requires_grad = False - - model.print_trainable_parameters() + if training_args.lora_enable: + model.print_trainable_parameters() model.config.visual_abstractor_lr = training_args.visual_abstractor_lr