Skip to content

Commit

Permalink
rename decode_latents/encode_image -> image_to_latent/latent_to_image
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Jan 29, 2024
1 parent 71f5c10 commit aaa9adc
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 50 deletions.
16 changes: 8 additions & 8 deletions src/refiners/foundationals/latent_diffusion/auto_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,25 +210,25 @@ def decode(self, x: Tensor) -> Tensor:
x = decoder(x / self.encoder_scale)
return x

def encode_image(self, image: Image.Image) -> Tensor:
    """Deprecated alias for `image_to_latent`, kept for backward compatibility."""
    return self.image_to_latent(image)

def image_to_latent(self, image: Image.Image) -> Tensor:
    """Encode a single PIL image into a latent tensor (batch size 1)."""
    return self.images_to_latents([image])

def encode_images(self, images: list[Image.Image]) -> Tensor:
    """Deprecated alias for `images_to_latents`, kept for backward compatibility."""
    return self.images_to_latents(images)

def images_to_latents(self, images: list[Image.Image]) -> Tensor:
    """Encode a batch of PIL images into a single latent tensor.

    Each image is converted to a tensor on this module's device/dtype and the
    results are concatenated along the batch dimension. Pixel values are
    rescaled from [0, 1] to [-1, 1], the range the encoder expects.
    """
    x = cat([image_to_tensor(image, device=self.device, dtype=self.dtype) for image in images], dim=0)
    x = 2 * x - 1  # map pixel range [0, 1] -> [-1, 1]
    return self.encode(x)


def decode_latents(self, x: Tensor) -> Image.Image:
    """Deprecated alias for `latent_to_image`, kept for backward compatibility."""
    return self.latent_to_image(x)

def decode_image(self, x: Tensor) -> Image.Image:
    """Deprecated alias for `latent_to_image`, kept for backward compatibility."""
    return self.latent_to_image(x)

def latent_to_image(self, x: Tensor) -> Image.Image:
    """Decode a single latent into a PIL image.

    Raises:
        ValueError: if the batch dimension of `x` is not 1.
    """
    if x.shape[0] != 1:
        raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
    return self.latents_to_images(x)[0]

def decode_images(self, x: Tensor) -> list[Image.Image]:
    """Deprecated alias for `latents_to_images`, kept for backward compatibility."""
    return self.latents_to_images(x)

def latents_to_images(self, x: Tensor) -> list[Image.Image]:
    """Decode a batch of latents into a list of PIL images."""
    x = self.decode(x)
    x = (x + 1) / 2  # map decoder output [-1, 1] -> pixel range [0, 1]
    return [tensor_to_image(t) for t in x.split(1)]
2 changes: 1 addition & 1 deletion src/refiners/foundationals/latent_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def init_latents(
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
if init_image is None:
return noise
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
encoded_image = self.lda.image_to_latent(image=init_image.resize(size=(width, height)))
return self.scheduler.add_noise(
x=encoded_image,
noise=noise,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,15 @@ def device(self) -> Device:
def dtype(self) -> DType:
return self.ldm.dtype

# backward-compatibility alias
def decode_latents(self, x: Tensor) -> Image.Image:
    """Deprecated alias for `latent_to_image`, kept for backward compatibility."""
    return self.latent_to_image(x=x)

def latent_to_image(self, x: Tensor) -> Image.Image:
    """Decode a single latent into a PIL image via the wrapped model's autoencoder."""
    return self.ldm.lda.latent_to_image(x=x)

def latents_to_images(self, x: Tensor) -> list[Image.Image]:
    """Decode a batch of latents into PIL images via the wrapped model's autoencoder."""
    return self.ldm.lda.latents_to_images(x=x)

@staticmethod
def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
Expand Down
4 changes: 2 additions & 2 deletions src/refiners/training_utils/latent_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
max_size=self.config.dataset.resize_image_max_size,
)
processed_image = self.process_image(resized_image)
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
latents = self.lda.image_to_latent(image=processed_image).to(device=self.device)
processed_caption = self.process_caption(caption=caption)
clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)
Expand Down Expand Up @@ -202,7 +202,7 @@ def compute_evaluation(self) -> None:
step=step,
clip_text_embedding=clip_text_embedding,
)
canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
canvas_image.paste(sd.lda.latent_to_image(x=x), box=(0, 512 * i))
images[prompt] = canvas_image
self.log(data=images)

Expand Down
68 changes: 34 additions & 34 deletions tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ def test_diffusion_std_random_init(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_std_random_init)

Expand Down Expand Up @@ -666,7 +666,7 @@ def test_diffusion_std_random_init_euler(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_std_random_init_euler)

Expand All @@ -691,7 +691,7 @@ def test_diffusion_karras_random_init(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -719,7 +719,7 @@ def test_diffusion_std_random_init_float16(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -747,7 +747,7 @@ def test_diffusion_std_random_init_sag(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_std_random_init_sag)

Expand Down Expand Up @@ -776,7 +776,7 @@ def test_diffusion_std_init_image(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_std_init_image)

Expand All @@ -793,7 +793,7 @@ def test_rectangular_init_latents(
rect_init_image = cutecat_init.crop((0, 0, width, height))
x = sd15.init_latents((height, width), rect_init_image)

assert sd15.lda.decode_latents(x).size == (width, height)
assert sd15.lda.latent_to_image(x).size == (width, height)


@no_grad()
Expand Down Expand Up @@ -823,7 +823,7 @@ def test_diffusion_inpainting(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

# PSNR and SSIM values are large because with float32 we get large differences even v.s. ourselves.
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
Expand Down Expand Up @@ -857,7 +857,7 @@ def test_diffusion_inpainting_float16(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

# PSNR and SSIM values are large because float16 is even worse than float32.
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
Expand Down Expand Up @@ -900,7 +900,7 @@ def test_diffusion_controlnet(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -943,7 +943,7 @@ def test_diffusion_controlnet_structural_copy(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -985,7 +985,7 @@ def test_diffusion_controlnet_float16(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -1039,7 +1039,7 @@ def test_diffusion_controlnet_stack(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -1071,7 +1071,7 @@ def test_diffusion_lora(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -1114,7 +1114,7 @@ def test_diffusion_sdxl_lora(
condition_scale=guidance_scale,
)

predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -1162,7 +1162,7 @@ def test_diffusion_sdxl_multiple_loras(
condition_scale=guidance_scale,
)

predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

Expand All @@ -1181,7 +1181,7 @@ def test_diffusion_refonly(

refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()

guide = sd15.lda.encode_image(condition_image_refonly)
guide = sd15.lda.image_to_latent(condition_image_refonly)
guide = torch.cat((guide, guide))

manual_seed(2)
Expand All @@ -1198,7 +1198,7 @@ def test_diffusion_refonly(
condition_scale=7.5,
)
torch.randn(2, 4, 64, 64, device=test_device) # for SD Web UI reproductibility only
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

# min_psnr lowered to 33 because this reference image was generated without noise removal (see #192)
ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=33, min_ssim=0.99)
Expand All @@ -1223,7 +1223,7 @@ def test_diffusion_inpainting_refonly(
sd15.set_inference_steps(30)
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)

guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
guide = sd15.lda.image_to_latent(scene_image_inpainting_refonly)
guide = torch.cat((guide, guide))

manual_seed(2)
Expand All @@ -1243,7 +1243,7 @@ def test_diffusion_inpainting_refonly(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)

Expand Down Expand Up @@ -1276,7 +1276,7 @@ def test_diffusion_textual_inversion_random_init(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -1321,7 +1321,7 @@ def test_diffusion_ip_adapter(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)

Expand Down Expand Up @@ -1370,7 +1370,7 @@ def test_diffusion_sdxl_ip_adapter(
# See https://huggingface.co/madebyollin/sdxl-vae-fp16-fix: "SDXL-VAE generates NaNs in fp16 because the
# internal activation values are too big"
sdxl.lda.to(dtype=torch.float32)
predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
predicted_image = sdxl.lda.latent_to_image(x.to(dtype=torch.float32))

ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)

Expand Down Expand Up @@ -1426,7 +1426,7 @@ def test_diffusion_ip_adapter_controlnet(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)

Expand Down Expand Up @@ -1467,7 +1467,7 @@ def test_diffusion_ip_adapter_plus(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -1514,7 +1514,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
condition_scale=5,
)
sdxl.lda.to(dtype=torch.float32)
predicted_image = sdxl.lda.decode_latents(x.to(dtype=torch.float32))
predicted_image = sdxl.lda.latent_to_image(x.to(dtype=torch.float32))

ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)

Expand Down Expand Up @@ -1548,7 +1548,7 @@ def test_sdxl_random_init(
time_ids=time_ids,
condition_scale=5,
)
predicted_image = sdxl.lda.decode_latents(x=x)
predicted_image = sdxl.lda.latent_to_image(x=x)

ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -1583,7 +1583,7 @@ def test_sdxl_random_init_sag(
time_ids=time_ids,
condition_scale=5,
)
predicted_image = sdxl.lda.decode_latents(x=x)
predicted_image = sdxl.lda.latent_to_image(x=x)

ensure_similar_images(img_1=predicted_image, img_2=expected_image)

Expand Down Expand Up @@ -1615,7 +1615,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
step=step,
targets=[target_1, target_2],
)
result = sd.lda.decode_latents(x=x)
result = sd.lda.latent_to_image(x=x)
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)


Expand Down Expand Up @@ -1654,7 +1654,7 @@ def test_t2i_adapter_depth(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image)

Expand Down Expand Up @@ -1702,7 +1702,7 @@ def test_t2i_adapter_xl_canny(
time_ids=time_ids,
condition_scale=7.5,
)
predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image)

Expand Down Expand Up @@ -1741,7 +1741,7 @@ def test_restart(
condition_scale=8,
)

predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)

Expand Down Expand Up @@ -1773,7 +1773,7 @@ def test_freeu(
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
predicted_image = sd15.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_freeu)

Expand Down Expand Up @@ -1829,6 +1829,6 @@ def test_hello_world(
pooled_text_embedding=pooled_text_embedding,
time_ids=time_ids,
)
predicted_image = sdxl.lda.decode_latents(x)
predicted_image = sdxl.lda.latent_to_image(x)

ensure_similar_images(predicted_image, expected_image)
Loading

0 comments on commit aaa9adc

Please sign in to comment.