From 08f1a592ea59293ff3bc9a22a33a21febe740f31 Mon Sep 17 00:00:00 2001 From: Raphael Glon Date: Fri, 25 Oct 2024 14:40:01 +0200 Subject: [PATCH] Diffusers, txt2img and img2img, make sure guidance scale defaults to 0 when num steps <=4 Signed-off-by: Raphael Glon --- docker_images/diffusers/app/pipelines/image_to_image.py | 6 ++++++ docker_images/diffusers/app/pipelines/text_to_image.py | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/docker_images/diffusers/app/pipelines/image_to_image.py b/docker_images/diffusers/app/pipelines/image_to_image.py index d8fcffc2..97359bc6 100644 --- a/docker_images/diffusers/app/pipelines/image_to_image.py +++ b/docker_images/diffusers/app/pipelines/image_to_image.py @@ -233,6 +233,12 @@ def _process_req(self, image, prompt, **kwargs): "negative_prompt": kwargs.get("negative_prompt", None), "guidance_scale": kwargs.get("guidance_scale", 7), } + if "guidance_scale" not in kwargs: + default_guidance_scale = os.getenv("DEFAULT_GUIDANCE_SCALE") + if default_guidance_scale is not None: + kwargs["guidance_scale"] = float(default_guidance_scale) + prior_args["guidance_scale"] = float(default_guidance_scale) + # Else, don't specify anything, leave the default behaviour image_emb, zero_image_emb = self.prior(prompt, **prior_args).to_tuple() images = self.ldm( prompt, diff --git a/docker_images/diffusers/app/pipelines/text_to_image.py b/docker_images/diffusers/app/pipelines/text_to_image.py index 7fafe10b..4dcdc613 100644 --- a/docker_images/diffusers/app/pipelines/text_to_image.py +++ b/docker_images/diffusers/app/pipelines/text_to_image.py @@ -169,6 +169,11 @@ def _process_req(self, inputs, **kwargs): kwargs["num_inference_steps"] = 20 # Else, don't specify anything, leave the default behaviour + if "guidance_scale" not in kwargs: + default_guidance_scale = os.getenv("DEFAULT_GUIDANCE_SCALE") + if default_guidance_scale is not None: + kwargs["guidance_scale"] = float(default_guidance_scale) + # Else, don't specify anything, leave the default behaviour if "seed" in kwargs: seed = int(kwargs["seed"]) generator = torch.Generator().manual_seed(seed)