model component merge
Signed-off-by: Vladimir Mandic <[email protected]>
vladmandic committed Jan 19, 2025
1 parent de10632 commit 311d402
Showing 11 changed files with 836 additions and 248 deletions.
13 changes: 12 additions & 1 deletion CHANGELOG.md
@@ -2,15 +2,26 @@

## Update for 2025-01-19

- **Model Merge**
- replace model components and merge LoRAs
in addition to the existing model weights merge support, you can now also replace individual model components and merge LoRAs
merges can be tested fully in-memory, without needing to save anything to disk
it can also be used to convert a diffusers model into a single safetensors file
*example*: replace the vae in your favorite model with a fixed one, swap out the text encoder, etc. (a rough diffusers sketch of what this amounts to follows this changelog excerpt)
*note*: limited to sdxl for now, additional model types can be added depending on popularity
- **Detailer**:
- in addition to the standard detect & re-generate behavior, it can now also run face-restore models
- included models are: *CodeFormer, RestoreFormer, GFPGan, GPEN-BFR*
- **Other**:
- **ipex**: update supported torch versions
- **gallery**: add http fallback for slow/unreliable links
- **upscale**: code refactor to unify latent, resize and model based upscalers
- **splash**: add legacy mode indicator on splash screen
- **network**: extract thumbnail from model metadata if present
- **Refactor**:
- **upscale**: code refactor to unify latent, resize and model based upscalers
- **loader**: ability to run in-memory models
- **schedulers**: ability to create model-less schedulers
- **Fixes**:
- non-full vae decode
- send-to image transfer
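
Below is a minimal, hedged sketch of what the component replace / LoRA merge amounts to when expressed directly against the diffusers library, independent of the new modules/merging/modules_sdxl.py code shown further down in this commit. The checkpoint path, LoRA file, VAE repo id and fuse strength are placeholder assumptions for illustration, not values taken from this commit:

```python
import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL

# load a base sdxl checkpoint from a single safetensors file (placeholder path)
pipe = StableDiffusionXLPipeline.from_single_file(
    "base-sdxl.safetensors", torch_dtype=torch.float16)

# replace a component: swap the vae for a known-good fp16 vae (placeholder repo id)
pipe.vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# merge a lora into the model weights at a chosen strength (placeholder file and scale)
pipe.load_lora_weights("my-style-lora.safetensors")
pipe.fuse_lora(lora_scale=0.7)
pipe.unload_lora_weights()  # drop the adapter layers, keep the fused weights

# the merged pipeline is usable entirely in-memory at this point;
# save_pretrained writes a diffusers folder, and per the changelog the new
# merge can additionally write the result back out as a single safetensors checkpoint
pipe.save_pretrained("merged-sdxl")
```

The in-memory path is the interesting part: an sdxl merge is several gigabytes, so being able to load and test the result without a disk round-trip makes iterating on merge recipes much faster.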
2 changes: 1 addition & 1 deletion javascript/base.css
@@ -77,7 +77,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
#extensions .info { margin: 0; }
#extensions .date { opacity: 0.85; font-size: 90%; }

/* extra networks */
/* networks */
.extra-networks > div { margin: 0; border-bottom: none !important; }
.extra-networks .second-line { display: flex; width: -moz-available; width: -webkit-fill-available; gap: 0.3em; box-shadow: var(--input-shadow); margin-bottom: 2px; }
.extra-networks .search { flex: 1; }
7 changes: 5 additions & 2 deletions javascript/sdnext.css
@@ -207,7 +207,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
#extensions .info { margin: 0; }
#extensions .date { opacity: 0.85; font-size: var(--text-sm); }

/* extra networks */
/* networks */
.extra_networks_root { width: 0; position: absolute; height: auto; right: 0; top: 13em; z-index: 100; } /* default is sidebar view */
.extra-networks { background: var(--background-color); padding: var(--block-label-padding); }
.extra-networks > div { margin: 0; border-bottom: none !important; gap: 0.3em 0; }
@@ -269,11 +269,14 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
.ar-dropdown div { margin: 0; background: var(--background-color)}
#txt2img_sampler_timesteps, #img2img_sampler_timesteps { max-width: calc(var(--left-column) - 50px); }

/* extras */
/* models */
.extras { gap: 0.2em 1em !important }
#extras_generate, #extras_interrupt, #extras_skip { display: block !important; position: relative; height: 36px; }
#extras_upscale { margin-top: 10px }
#pnginfo_html_info .gradio-html > div { margin: 0.5em; }
#models_image, #models_image > div { min-height: 0; }
#models_error { font-family: monospace; color: var(--body-text-color-subdued) }


/* log monitor */
.log-monitor { display: none; justify-content: unset !important; overflow: hidden; padding: 0; margin-top: auto; font-family: monospace; font-size: var(--text-xxs); }
222 changes: 78 additions & 144 deletions modules/extras.py
@@ -4,14 +4,12 @@
import time
import shutil

from PIL import Image
import torch
import tqdm
import gradio as gr
import safetensors.torch
from modules.merging.merge import merge_models
from modules.merging.merge_utils import TRIPLE_METHODS

from modules import shared, images, sd_models, sd_vae, sd_models_config, devices
from modules.merging import merge, merge_utils, modules_sdxl
from modules import shared, images, sd_models, sd_vae, sd_samplers, sd_models_config, devices


def run_pnginfo(image):
@@ -73,9 +73,9 @@ def fail(message):
if kwargs.get("secondary_model_name", None) in [None, 'None']:
return fail("Failed: Merging requires a secondary model.")
secondary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("secondary_model_name", None))
if kwargs.get("tertiary_model_name", None) in [None, 'None'] and kwargs.get("merge_mode", None) in TRIPLE_METHODS:
if kwargs.get("tertiary_model_name", None) in [None, 'None'] and kwargs.get("merge_mode", None) in merge_utils.TRIPLE_METHODS:
return fail(f"Failed: Interpolation method ({kwargs.get('merge_mode', None)}) requires a tertiary model.")
tertiary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("tertiary_model_name", None)) if kwargs.get("merge_mode", None) in TRIPLE_METHODS else None
tertiary_model_info = sd_models.get_closet_checkpoint_match(kwargs.get("tertiary_model_name", None)) if kwargs.get("merge_mode", None) in merge_utils.TRIPLE_METHODS else None

del kwargs["primary_model_name"]
del kwargs["secondary_model_name"]
@@ -128,7 +126,7 @@ def fail(message):
sd_models.unload_model_weights()

try:
theta_0 = merge_models(**kwargs)
theta_0 = merge.merge_models(**kwargs)
except Exception as e:
return fail(f"{e}")

@@ -205,144 +203,80 @@ def add_model_metadata(checkpoint_info):
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_titles()) for _ in range(4)], f"Model saved to {output_modelname}"]


def run_modelconvert(model, checkpoint_formats, precision, conv_type, custom_name, unet_conv, text_encoder_conv,
vae_conv, others_conv, fix_clip):
# position_ids in clip is int64. model_ema.num_updates is int32
dtypes_to_fp16 = {torch.float32, torch.float64, torch.bfloat16}
dtypes_to_bf16 = {torch.float32, torch.float64, torch.float16}
def run_model_modules(model_type:str, model_name:str, custom_name:str,
comp_unet:str, comp_vae:str, comp_te1:str, comp_te2:str,
precision:str, comp_scheduler:str, comp_prediction:str,
comp_lora:str, comp_fuse:float,
meta_author:str, meta_version:str, meta_license:str, meta_desc:str, meta_hint:str, meta_thumbnail:Image.Image,
create_diffusers:bool, create_safetensors:bool, debug:bool):

def conv_fp16(t: torch.Tensor):
return t.half() if t.dtype in dtypes_to_fp16 else t
status = ''
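# msg() helper: log the text as info or error and append it to the accumulated html status string that this generator yields back to the ui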
def msg(text, err:bool=False):
nonlocal status
if err:
shared.log.error(f'Modules merge: {text}')
else:
shared.log.info(f'Modules merge: {text}')
status += text + '<br>'
return status

def conv_bf16(t: torch.Tensor):
return t.bfloat16() if t.dtype in dtypes_to_bf16 else t
if model_type != 'sdxl':
yield msg("only SDXL models are supported", err=True)
return
if len(custom_name) == 0:
yield msg("output name is required", err=True)
return
checkpoint_info = sd_models.get_closet_checkpoint_match(model_name)
if checkpoint_info is None:
yield msg("input model not found", err=True)
return
fn = checkpoint_info.filename
shared.state.begin('Merge')
yield msg("modules merge starting")
yield msg("unload current model")
sd_models.unload_model_weights(op='model')

modules_sdxl.recipe.name = custom_name
modules_sdxl.recipe.author = meta_author
modules_sdxl.recipe.version = meta_version
modules_sdxl.recipe.desc = meta_desc
modules_sdxl.recipe.hint = meta_hint
modules_sdxl.recipe.license = meta_license
modules_sdxl.recipe.thumbnail = meta_thumbnail
modules_sdxl.recipe.base = fn
modules_sdxl.recipe.unet = comp_unet
modules_sdxl.recipe.vae = comp_vae
modules_sdxl.recipe.te1 = comp_te1
modules_sdxl.recipe.te2 = comp_te2
modules_sdxl.recipe.prediction = comp_prediction
modules_sdxl.recipe.diffusers = create_diffusers
modules_sdxl.recipe.safetensors = create_safetensors
modules_sdxl.recipe.fuse = float(comp_fuse)
modules_sdxl.recipe.debug = debug
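# the recipe collects every merge input (base checkpoint, replacement components, metadata, output formats, precision); it is filled in here and consumed by modules_sdxl.merge() below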

loras = [l.strip() if ':' in l else f'{l.strip()}:1.0' for l in comp_lora.split(',') if len(l.strip()) > 0]
for lora, strength in [l.split(':') for l in loras]:
modules_sdxl.recipe.lora[lora] = float(strength)
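# comp_lora parsed above is a comma-separated list of "name" or "name:strength" entries; entries without an explicit strength default to 1.0, e.g. a hypothetical "lora-a:0.6, lora-b" yields strengths 0.6 and 1.0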
scheduler = sd_samplers.create_sampler(comp_scheduler, None)
modules_sdxl.recipe.scheduler = scheduler.__class__.__name__ if scheduler is not None else None
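# the sampler is created with model=None, i.e. a model-less scheduler, only to resolve the scheduler class name that gets recorded in the recipe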
if precision == 'fp32':
modules_sdxl.recipe.precision = torch.float32
elif precision == 'bf16':
modules_sdxl.recipe.precision = torch.bfloat16
else:
modules_sdxl.recipe.precision = torch.float16

def conv_full(t):
return t
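# hand the accumulated status string to the merge module, stream its progress messages back to the ui, then pick up the updated status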
modules_sdxl.status = status
yield from modules_sdxl.merge()
status = modules_sdxl.status

_g_precision_func = {
"full": conv_full,
"fp32": conv_full,
"fp16": conv_fp16,
"bf16": conv_bf16,
}

def check_weight_type(k: str) -> str:
if k.startswith("model.diffusion_model"):
return "unet"
elif k.startswith("first_stage_model"):
return "vae"
elif k.startswith("cond_stage_model"):
return "clip"
return "other"

def load_model(path):
if path.endswith(".safetensors"):
m = safetensors.torch.load_file(path, device="cpu")
else:
m = torch.load(path, map_location="cpu")
state_dict = m["state_dict"] if "state_dict" in m else m
return state_dict

def fix_model(model, fix_clip=False):
# code from model-toolkit
nai_keys = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
}
for k in list(model.keys()):
for r in nai_keys:
if type(k) == str and k.startswith(r):
new_key = k.replace(r, nai_keys[r])
model[new_key] = model[k]
del model[k]
shared.log.warning(f"Model convert: fixed NovelAI error key: {k}")
break
if fix_clip:
i = "cond_stage_model.transformer.text_model.embeddings.position_ids"
if i in model:
correct = torch.Tensor([list(range(77))]).to(torch.int64)
now = model[i].to(torch.int64)

broken = correct.ne(now)
broken = [i for i in range(77) if broken[0][i]]
model[i] = correct
if len(broken) != 0:
shared.log.warning(f"Model convert: fixed broken CLiP: {broken}")

return model

if model == "":
return "Error: you must choose a model"
if len(checkpoint_formats) == 0:
return "Error: at least choose one model save format"

extra_opt = {
"unet": unet_conv,
"clip": text_encoder_conv,
"vae": vae_conv,
"other": others_conv
}
shared.state.begin('Convert')
model_info = sd_models.checkpoints_list[model]
shared.state.textinfo = f"Load {model_info.filename}..."
shared.log.info(f"Model convert loading: {model_info.filename}")
state_dict = load_model(model_info.filename)

ok = {} # {"state_dict": {}}

conv_func = _g_precision_func[precision]

def _hf(wk: str, t: torch.Tensor):
if not isinstance(t, torch.Tensor):
return
w_t = check_weight_type(wk)
conv_t = extra_opt[w_t]
if conv_t == "convert":
ok[wk] = conv_func(t)
elif conv_t == "copy":
ok[wk] = t
elif conv_t == "delete":
return

shared.log.info("Model convert: running")
if conv_type == "ema-only":
for k in tqdm.tqdm(state_dict):
ema_k = "___"
try:
ema_k = "model_ema." + k[6:].replace(".", "")
except Exception:
pass
if ema_k in state_dict:
_hf(k, state_dict[ema_k])
elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
_hf(k, state_dict[k])
elif conv_type == "no-ema":
for k, v in tqdm.tqdm(state_dict.items()):
if "model_ema." not in k:
_hf(k, v)
else:
for k, v in tqdm.tqdm(state_dict.items()):
_hf(k, v)

ok = fix_model(ok, fix_clip=fix_clip)
output = ""
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
save_name = f"{model_info.model_name}-{precision}"
if conv_type != "disabled":
save_name += f"-{conv_type}"
if custom_name != "":
save_name = custom_name
for fmt in checkpoint_formats:
ext = ".safetensors" if fmt == "safetensors" else ".ckpt"
_save_name = save_name + ext
save_path = os.path.join(ckpt_dir, _save_name)
shared.log.info(f"Model convert saving: {save_path}")
if fmt == "safetensors":
safetensors.torch.save_file(ok, save_path)
else:
torch.save({"state_dict": ok}, save_path)
output += f"Checkpoint saved to {save_path}<br>"
devices.torch_gc(force=True)
yield msg("modules merge complete")
if modules_sdxl.pipeline is not None:
checkpoint_info = sd_models.CheckpointInfo(filename='None')
shared.sd_model = modules_sdxl.pipeline
sd_models.set_defaults(shared.sd_model, checkpoint_info)
sd_models.set_diffuser_options(shared.sd_model, offload=False)
sd_models.set_diffuser_offload(shared.sd_model)
yield msg("pipeline loaded")
shared.state.end()
return output