From f3d2b6c3250a3151a321d5a8e8974223dd8a6050 Mon Sep 17 00:00:00 2001 From: Laurent Date: Thu, 3 Oct 2024 08:47:37 +0000 Subject: [PATCH] modify some foundational tests to also test in float16 and bfloat16 --- .../foundationals/clip/test_image_encoder.py | 35 ++++++--- tests/foundationals/clip/test_text_encoder.py | 31 +++++--- tests/foundationals/dinov2/test_dinov2.py | 15 ++-- .../latent_diffusion/test_auto_encoder.py | 45 ++++++++--- .../latent_diffusion/test_models.py | 77 +++++++++++++++++++ .../latent_diffusion/test_sd15_unet.py | 12 ++- 6 files changed, 170 insertions(+), 45 deletions(-) create mode 100644 tests/foundationals/latent_diffusion/test_models.py diff --git a/tests/foundationals/clip/test_image_encoder.py b/tests/foundationals/clip/test_image_encoder.py index ff990bda5..69988c208 100644 --- a/tests/foundationals/clip/test_image_encoder.py +++ b/tests/foundationals/clip/test_image_encoder.py @@ -10,12 +10,16 @@ @pytest.fixture(scope="module") -def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPImageEncoderH: +def our_encoder( + test_weights_path: Path, + test_device: torch.device, + test_dtype_fp32_bf16_fp16: torch.dtype, +) -> CLIPImageEncoderH: weights = test_weights_path / "CLIPImageEncoderH.safetensors" if not weights.is_file(): warn(f"could not find weights at {weights}, skipping") pytest.skip(allow_module_level=True) - encoder = CLIPImageEncoderH(device=test_device) + encoder = CLIPImageEncoderH(device=test_device, dtype=test_dtype_fp32_bf16_fp16) tensors = load_from_safetensors(weights) encoder.load_state_dict(tensors) return encoder @@ -31,24 +35,31 @@ def stabilityai_unclip_weights_path(test_weights_path: Path): @pytest.fixture(scope="module") -def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection: - return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to( # type: ignore - test_device # type: ignore - ) +def ref_encoder( + stabilityai_unclip_weights_path: Path, + test_device: torch.device, + test_dtype_fp32_bf16_fp16: torch.dtype, +) -> CLIPVisionModelWithProjection: + return CLIPVisionModelWithProjection.from_pretrained( # type: ignore + stabilityai_unclip_weights_path, + subfolder="image_encoder", + ).to(device=test_device, dtype=test_dtype_fp32_bf16_fp16) +@no_grad() +@pytest.mark.flaky(reruns=3) def test_encoder( ref_encoder: CLIPVisionModelWithProjection, our_encoder: CLIPImageEncoderH, - test_device: torch.device, ): - x = torch.randn(1, 3, 224, 224).to(test_device) + assert ref_encoder.dtype == our_encoder.dtype + assert ref_encoder.device == our_encoder.device + x = torch.randn((1, 3, 224, 224), dtype=ref_encoder.dtype, device=ref_encoder.device) - with no_grad(): - ref_embeddings = ref_encoder(x).image_embeds - our_embeddings = our_encoder(x) + ref_embeddings = ref_encoder(x).image_embeds + our_embeddings = our_encoder(x) assert ref_embeddings.shape == (1, 1024) assert our_embeddings.shape == (1, 1024) - assert (our_embeddings - ref_embeddings).abs().max() < 0.01 + assert torch.allclose(our_embeddings, ref_embeddings, atol=0.05) diff --git a/tests/foundationals/clip/test_text_encoder.py b/tests/foundationals/clip/test_text_encoder.py index 60a873bb2..28eeada65 100644 --- a/tests/foundationals/clip/test_text_encoder.py +++ b/tests/foundationals/clip/test_text_encoder.py @@ -30,13 +30,17 @@ @pytest.fixture(scope="module") -def our_encoder(test_weights_path: Path, test_device: torch.device) -> CLIPTextEncoderL: +def our_encoder( + test_weights_path: Path, + test_device: torch.device, + test_dtype_fp32_fp16: torch.dtype, +) -> CLIPTextEncoderL: weights = test_weights_path / "CLIPTextEncoderL.safetensors" if not weights.is_file(): warn(f"could not find weights at {weights}, skipping") pytest.skip(allow_module_level=True) - encoder = CLIPTextEncoderL(device=test_device) tensors = load_from_safetensors(weights) + encoder = CLIPTextEncoderL(device=test_device, dtype=test_dtype_fp32_fp16) encoder.load_state_dict(tensors) return encoder @@ -56,8 +60,15 @@ def ref_tokenizer(runwayml_weights_path: Path) -> transformers.CLIPTokenizer: @pytest.fixture(scope="module") -def ref_encoder(runwayml_weights_path: Path, test_device: torch.device) -> transformers.CLIPTextModel: - return transformers.CLIPTextModel.from_pretrained(runwayml_weights_path, subfolder="text_encoder").to(test_device) # type: ignore +def ref_encoder( + runwayml_weights_path: Path, + test_device: torch.device, + test_dtype_fp32_fp16: torch.dtype, +) -> transformers.CLIPTextModel: + return transformers.CLIPTextModel.from_pretrained( # type: ignore + runwayml_weights_path, + subfolder="text_encoder", + ).to(device=test_device, dtype=test_dtype_fp32_fp16) # type: ignore def test_basics(ref_tokenizer: transformers.CLIPTokenizer, our_encoder: CLIPTextEncoderL): @@ -70,12 +81,12 @@ def prompt(request: pytest.FixtureRequest): return long_prompt if request.param == "" else request.param +@no_grad() def test_encoder( prompt: str, ref_tokenizer: transformers.CLIPTokenizer, ref_encoder: transformers.CLIPTextModel, our_encoder: CLIPTextEncoderL, - test_device: torch.device, ): ref_tokens = ref_tokenizer( # type: ignore prompt, @@ -89,18 +100,16 @@ def test_encoder( our_tokens = tokenizer(prompt) assert torch.equal(our_tokens, ref_tokens) - with no_grad(): - ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0] - our_embeddings = our_encoder(prompt) + ref_embeddings = ref_encoder(ref_tokens.to(device=ref_encoder.device))[0] + our_embeddings = our_encoder(prompt) assert ref_embeddings.shape == (1, 77, 768) assert our_embeddings.shape == (1, 77, 768) # FG-336 - Not strictly equal because we do not use the same implementation # of self-attention. We use `scaled_dot_product_attention` which can have - # numerical differences depending on the backend. - # Also we use FP16 weights. - assert (our_embeddings - ref_embeddings).abs().max() < 0.01 + # numerical differences depending on the backend. Also we use FP16 weights. + torch.testing.assert_close(our_embeddings, ref_embeddings, atol=0.035, rtol=0.0) def test_list_string_tokenizer( diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py index 47f7959b0..1e73b6dff 100644 --- a/tests/foundationals/dinov2/test_dinov2.py +++ b/tests/foundationals/dinov2/test_dinov2.py @@ -109,7 +109,7 @@ def test_dinov2_facebook_weights( ) -> None: manual_seed(2) input_data = torch.randn( - (1, 3, resolution, resolution), + size=(1, 3, resolution, resolution), device=test_device, ) @@ -129,27 +129,28 @@ def test_dinov2_facebook_weights( @no_grad() -def test_dinov2_float16( +def test_dinov2( resolution: int, + test_dtype_fp32_bf16_fp16: torch.dtype, test_device: torch.device, ) -> None: if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() - model = DINOv2_small(device=test_device, dtype=torch.float16) + model = DINOv2_small(device=test_device, dtype=test_dtype_fp32_bf16_fp16) manual_seed(2) input_data = torch.randn( - (1, 3, resolution, resolution), + size=(1, 3, resolution, resolution), device=test_device, - dtype=torch.float16, + dtype=test_dtype_fp32_bf16_fp16, ) output = model(input_data) sequence_length = (resolution // model.patch_size) ** 2 + 1 assert output.shape == (1, sequence_length, model.embedding_dim) - assert output.dtype == torch.float16 + assert output.dtype == test_dtype_fp32_bf16_fp16 @no_grad() @@ -162,7 +163,7 @@ def test_dinov2_batch_size( batch_size = 4 manual_seed(2) input_data = torch.randn( - (batch_size, 3, resolution, resolution), + size=(batch_size, 3, resolution, resolution), device=test_device, ) diff --git a/tests/foundationals/latent_diffusion/test_auto_encoder.py b/tests/foundationals/latent_diffusion/test_auto_encoder.py index dc6d77cc9..eedd25112 100644 --- a/tests/foundationals/latent_diffusion/test_auto_encoder.py +++ b/tests/foundationals/latent_diffusion/test_auto_encoder.py @@ -6,8 +6,8 @@ from PIL import Image from tests.utils import ensure_similar_images -from refiners.fluxion.utils import load_from_safetensors, no_grad -from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder +from refiners.fluxion.utils import no_grad +from refiners.foundationals.latent_diffusion import LatentDiffusionAutoencoder, SD1Autoencoder, SDXLAutoencoder @pytest.fixture(scope="module") @@ -15,16 +15,37 @@ def ref_path() -> Path: return Path(__file__).parent / "test_auto_encoder_ref" -@pytest.fixture(scope="module") -def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder: - lda_weights = test_weights_path / "lda.safetensors" - if not lda_weights.is_file(): - warn(f"could not find weights at {lda_weights}, skipping") - pytest.skip(allow_module_level=True) - encoder = LatentDiffusionAutoencoder(device=test_device) - tensors = load_from_safetensors(lda_weights) - encoder.load_state_dict(tensors) - return encoder +@pytest.fixture(scope="module", params=["SD1.5", "SDXL"]) +def lda( + request: pytest.FixtureRequest, + test_weights_path: Path, + test_dtype_fp32_bf16_fp16: torch.dtype, + test_device: torch.device, +) -> LatentDiffusionAutoencoder: + model_version = request.param + match (model_version, test_dtype_fp32_bf16_fp16): + case ("SD1.5", _): + weight_path = test_weights_path / "lda.safetensors" + if not weight_path.is_file(): + warn(f"could not find weights at {weight_path}, skipping") + pytest.skip(allow_module_level=True) + model = SD1Autoencoder().load_from_safetensors(weight_path) + case ("SDXL", torch.float16): + weight_path = test_weights_path / "sdxl-lda-fp16-fix.safetensors" + if not weight_path.is_file(): + warn(f"could not find weights at {weight_path}, skipping") + pytest.skip(allow_module_level=True) + model = SDXLAutoencoder().load_from_safetensors(weight_path) + case ("SDXL", _): + weight_path = test_weights_path / "sdxl-lda.safetensors" + if not weight_path.is_file(): + warn(f"could not find weights at {weight_path}, skipping") + pytest.skip(allow_module_level=True) + model = SDXLAutoencoder().load_from_safetensors(weight_path) + case _: + raise ValueError(f"Unknown model version: {model_version}") + model = model.to(device=test_device, dtype=test_dtype_fp32_bf16_fp16) + return model @pytest.fixture(scope="module") diff --git a/tests/foundationals/latent_diffusion/test_models.py b/tests/foundationals/latent_diffusion/test_models.py new file mode 100644 index 000000000..5fd74c878 --- /dev/null +++ b/tests/foundationals/latent_diffusion/test_models.py @@ -0,0 +1,77 @@ +import torch +from PIL import Image + +from refiners.fluxion.utils import manual_seed, no_grad +from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting, StableDiffusion_XL +from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel + + +@no_grad() +def test_sample_noise_zero_offset(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None: + manual_seed(2) + latents_0 = LatentDiffusionModel.sample_noise( + size=(1, 4, 64, 64), + device=test_device, + dtype=test_dtype_fp32_bf16_fp16, + ) + manual_seed(2) + latents_1 = LatentDiffusionModel.sample_noise( + size=(1, 4, 64, 64), + offset_noise=0.0, # should be no-op + device=test_device, + dtype=test_dtype_fp32_bf16_fp16, + ) + + assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0) + + +@no_grad() +def test_sd15_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None: + sd = StableDiffusion_1(device=test_device, dtype=test_dtype_fp32_bf16_fp16) + + # prepare inputs + latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16) + text_embedding = sd.compute_clip_text_embedding("") + + # run the pipeline of models, for a single step + output = sd(latent_noise, step=0, clip_text_embedding=text_embedding) + + assert output.shape == (1, 4, 64, 64) + + +@no_grad() +def test_sd15_inpainting_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None: + sd = StableDiffusion_1_Inpainting(device=test_device, dtype=test_dtype_fp32_bf16_fp16) + + # prepare inputs + latent_noise = torch.randn(1, 4, 64, 64, device=test_device, dtype=test_dtype_fp32_bf16_fp16) + target_image = Image.new("RGB", (512, 512)) + mask = Image.new("L", (512, 512)) + sd.set_inpainting_conditions(target_image=target_image, mask=mask) + text_embedding = sd.compute_clip_text_embedding("") + + # run the pipeline of models, for a single step + output = sd(latent_noise, step=0, clip_text_embedding=text_embedding) + + assert output.shape == (1, 4, 64, 64) + + +@no_grad() +def test_sdxl_one_step(test_device: torch.device, test_dtype_fp32_bf16_fp16: torch.dtype) -> None: + sd = StableDiffusion_XL(device=test_device, dtype=test_dtype_fp32_bf16_fp16) + + # prepare inputs + latent_noise = torch.randn(1, 4, 128, 128, device=test_device, dtype=test_dtype_fp32_bf16_fp16) + text_embedding, pooled_text_embedding = sd.compute_clip_text_embedding("") + time_ids = sd.default_time_ids + + # run the pipeline of models, for a single step + output = sd( + latent_noise, + step=0, + clip_text_embedding=text_embedding, + pooled_text_embedding=pooled_text_embedding, + time_ids=time_ids, + ) + + assert output.shape == (1, 4, 128, 128) diff --git a/tests/foundationals/latent_diffusion/test_sd15_unet.py b/tests/foundationals/latent_diffusion/test_sd15_unet.py index 3ecf1ae4e..6c01a5465 100644 --- a/tests/foundationals/latent_diffusion/test_sd15_unet.py +++ b/tests/foundationals/latent_diffusion/test_sd15_unet.py @@ -7,9 +7,15 @@ @pytest.fixture(scope="module") -def refiners_sd15_unet(test_device: torch.device) -> SD1UNet: - unet = SD1UNet(in_channels=4, device=test_device) - return unet +def refiners_sd15_unet( + test_device: torch.device, + test_dtype_fp32_bf16_fp16: torch.dtype, +) -> SD1UNet: + return SD1UNet( + in_channels=4, + device=test_device, + dtype=test_dtype_fp32_bf16_fp16, + ) def test_unet_context_flush(refiners_sd15_unet: SD1UNet):