From a91cfd5c6cf198ab7b4a8495cfeda3e78092a91f Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Thu, 11 Jul 2024 21:47:46 +0800
Subject: [PATCH] Support concat of CLIP outputs for longer prompts

---
 gradio_app.py | 40 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 39 insertions(+), 1 deletion(-)

diff --git a/gradio_app.py b/gradio_app.py
index 831396b..a7d5eb7 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -86,6 +86,44 @@ def encode_cropped_prompt_77tokens(txt: str):
     return text_cond
 
 
+@torch.inference_mode()
+def encode_cropped_prompt(txt: str, max_length=225):
+    memory_management.load_models_to_gpu(text_encoder)
+    cond_ids = tokenizer(
+        txt,
+        padding="max_length",
+        max_length=max_length + 2,
+        truncation=True,
+        return_tensors="pt",
+    ).input_ids.to(device=text_encoder.device)
+    if max_length + 2 > tokenizer.model_max_length:
+        input_ids = cond_ids.squeeze(0)
+        id_list = list(range(1, max_length + 2 - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2))
+        text_cond_list = []
+        for i in id_list:
+            # Encode each chunk, then concatenate the results
+            ids_chunk = (
+                input_ids[0].unsqueeze(0),
+                input_ids[i : i + tokenizer.model_max_length - 2],
+                input_ids[-1].unsqueeze(0),
+            )
+            if torch.all(ids_chunk[1] == tokenizer.pad_token_id):
+                break
+            text_cond = text_encoder(torch.concat(ids_chunk).unsqueeze(0)).last_hidden_state
+            if text_cond_list == []:
+                # BOS token
+                text_cond_list.append(text_cond[:, :1])
+            text_cond_list.append(text_cond[:, 1:tokenizer.model_max_length - 1])
+        # EOS token
+        text_cond_list.append(text_cond[:, -1:])
+        text_cond = torch.concat(text_cond_list, dim=1)
+    else:
+        text_cond = text_encoder(
+            cond_ids, attention_mask=None
+        ).last_hidden_state
+    return text_cond.flatten(0, 1).unsqueeze(0)
+
+
 @torch.inference_mode()
 def pytorch2numpy(imgs):
     results = []
@@ -126,7 +164,7 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
 
     memory_management.load_models_to_gpu(text_encoder)
-    conds = encode_cropped_prompt_77tokens(prompt)
+    conds = encode_cropped_prompt(prompt)
     unconds = encode_cropped_prompt_77tokens(n_prompt)
 
     memory_management.load_models_to_gpu(unet)
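
For reference, the chunking in the patch is the usual long-prompt trick for CLIP text encoders: split the token ids into 75-token windows, wrap each window in its own BOS/EOS pair, encode every window separately, and concatenate the hidden states along the sequence dimension. Below is a minimal standalone sketch of the same idea; it assumes the stock openai/clip-vit-large-patch14 checkpoint from transformers rather than the tokenizer/text_encoder objects loaded by gradio_app.py, and the function name encode_long_prompt is illustrative only, not part of the patch:

# Minimal sketch (assumption: stock SD1.5-style CLIP text encoder from
# Hugging Face, not the loaders used in gradio_app.py). Each 77-token
# window carries its own BOS/EOS, is encoded separately, and the hidden
# states are concatenated along the sequence dimension.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

@torch.inference_mode()
def encode_long_prompt(txt: str, max_length: int = 225):
    ids = tokenizer(
        txt,
        padding="max_length",
        max_length=max_length + 2,  # room for BOS and EOS
        truncation=True,
        return_tensors="pt",
    ).input_ids[0]
    window = tokenizer.model_max_length  # 77 for CLIP
    chunks = []
    for i in range(1, max_length + 1, window - 2):
        body = ids[i : i + window - 2]  # 75 content tokens
        if torch.all(body == tokenizer.pad_token_id):
            break  # remaining windows are pure padding
        chunk = torch.cat([ids[:1], body, ids[-1:]]).unsqueeze(0)
        hidden = text_encoder(chunk).last_hidden_state
        if not chunks:
            chunks.append(hidden[:, :1])  # keep a single BOS embedding
        chunks.append(hidden[:, 1 : window - 1])
    chunks.append(hidden[:, -1:])  # keep a single EOS embedding
    return torch.cat(chunks, dim=1)  # (1, 1 + 75 * n_windows + 1, 768)

print(encode_long_prompt("a watercolor castle on a cliff, " * 20).shape)

With max_length=225 there are at most three 75-token windows, so the returned tensor is (1, 152, 768) or (1, 227, 768) depending on how many windows survive the all-padding check; the 77-token path in the patch keeps its original (1, 77, 768) shape, so the two remain interchangeable as UNet cross-attention conditioning.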