-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
modify some foundational tests to also test in float16 and bfloat16
- Loading branch information
1 parent
b20474f
commit f3d2b6c
Showing
6 changed files
with
170 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import torch | ||
from PIL import Image | ||
|
||
from refiners.fluxion.utils import manual_seed, no_grad | ||
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting, StableDiffusion_XL | ||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel | ||
|
||
|
||
@no_grad() | ||
def test_sample_noise_zero_offset(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None: | ||
manual_seed(2) | ||
latents_0 = LatentDiffusionModel.sample_noise( | ||
size=(1, 4, 64, 64), | ||
device=test_device, | ||
dtype=test_dtype_fp32_bf16_fp16, | ||
) | ||
manual_seed(2) | ||
latents_1 = LatentDiffusionModel.sample_noise( | ||
size=(1, 4, 64, 64), | ||
offset_noise=0.0, # should be no-op | ||
device=test_device, | ||
dtype=test_dtype_fp32_bf16_fp16, | ||
) | ||
|
||
assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0) | ||
|
||
|
||
@no_grad() | ||
def test_sd15_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None: | ||
sd = StableDiffusion_1(device=test_device, dtype=test_dtype_fp32_bf16_fp16) | ||
|
||
# prepare inputs | ||
latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16) | ||
text_embedding = sd.compute_clip_text_embedding("") | ||
|
||
# run the pipeline of models, for a single step | ||
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding) | ||
|
||
assert output.shape == (1, 4, 64, 64) | ||
|
||
|
||
@no_grad() | ||
def test_sd15_inpainting_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None: | ||
sd = StableDiffusion_1_Inpainting(device=test_device, dtype=test_dtype_fp32_bf16_fp16) | ||
|
||
# prepare inputs | ||
latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16) | ||
target_image = Image.new("RGB", (512, 512)) | ||
mask = Image.new("L", (512, 512)) | ||
sd.set_inpainting_conditions(target_image=target_image, mask=mask) | ||
text_embedding = sd.compute_clip_text_embedding("") | ||
|
||
# run the pipeline of models, for a single step | ||
output = sd(latent_noise, step=0, clip_text_embedding=text_embedding) | ||
|
||
assert output.shape == (1, 4, 64, 64) | ||
|
||
|
||
@no_grad() | ||
def test_sdxl_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None: | ||
sd = StableDiffusion_XL(device=test_device, dtype=test_dtype_fp32_bf16_fp16) | ||
|
||
# prepare inputs | ||
latent_noise = torch.randn(1, 4, 128, 128, device=test_device, dtype=test_dtype_fp32_bf16_fp16) | ||
text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding("") | ||
time_ids = sd.default_time_ids | ||
|
||
# run the pipeline of models, for a single step | ||
output = sd( | ||
latent_noise, | ||
step=0, | ||
clip_text_embedding=text_embedding, | ||
pooled_text_embedding=pooled_text_embedding, | ||
time_ids=time_ids, | ||
) | ||
|
||
assert output.shape == (1, 4, 128, 128) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters