From d0414f8aaafa308765b859997a70849a6e93aa3b Mon Sep 17 00:00:00 2001
From: Pierre Colle
Date: Thu, 15 Feb 2024 18:54:31 +0100
Subject: [PATCH] prepare configs

---
 configs/finetune-color-palette-piercus.toml | 78 ----------------
 configs/finetune-color-palette.toml | 75 ----------------
 configs/finetune-histogram-piercus.toml | 76 ----------------
 ...-autoencoder-4_16-32-b8-lr5e-3-kl_div.toml | 64 --------------
 ...toencoder-5b-2_4-8x4-b8-lr5e-3-kl_div.toml | 64 --------------
 .../finetune-color-palette-mlp.toml} | 44 +++++-----
 ...ain-histogram-autoencoder-ckpt-reload.toml | 2 +-
 .../finetune-histogram-jiont-training.toml | 0
 .../finetune-histogram-jiont-training.toml | 0
 ...inetune-color-palette-mlp-unweighted.toml} | 44 ++++++----
 .../finetune-color-palette-mlp-weighted.toml} | 35 ++++----
 .../finetune-histogram-xs-2-layer.toml | 88 -------------------
 scripts/training/train-color-palette.py | 63 +++++++++++++
 .../fluxion/adapters/color_palette.py | 1 -
 src/refiners/fluxion/adapters/histogram.py | 16 +++-
 .../training_utils/metrics/color_palette.py | 7 +-
 .../trainers/abstract_color_trainer.py | 13 ++-
 .../training_utils/trainers/color_palette.py | 63 +++++++------
 .../training_utils/trainers/histogram.py | 48 +---------
 19 files changed, 206 insertions(+), 575 deletions(-)
 delete mode 100644 configs/finetune-color-palette-piercus.toml
 delete mode 100644 configs/finetune-color-palette.toml
 delete mode 100644 configs/finetune-histogram-piercus.toml
 delete mode 100644 configs/histogram-auto-encoder/train-histogram-autoencoder-4_16-32-b8-lr5e-3-kl_div.toml
 delete mode 100644 configs/histogram-auto-encoder/train-histogram-autoencoder-5b-2_4-8x4-b8-lr5e-3-kl_div.toml
 rename configs/{scheduled-local/finetune-color-palette-mlp-transforms.toml => local/experiments/finetune-color-palette-mlp.toml} (71%)
 rename configs/{histogram-auto-encoder => local/pending}/train-histogram-autoencoder-ckpt-reload.toml (93%)
 create mode 100644 configs/local/scheduled/finetune-histogram-jiont-training.toml
 rename configs/{scheduled-remote => remote/pending}/finetune-histogram-jiont-training.toml (100%)
 rename configs/{scheduled-local/finetune-color-palette-mlp-noweight.toml => remote/scheduled/finetune-color-palette-mlp-unweighted.toml} (60%)
 rename configs/{scheduled-local/finetune-color-palette-mlp.toml => remote/scheduled/finetune-color-palette-mlp-weighted.toml} (72%)
 delete mode 100644 configs/scheduled-local-histogram/finetune-histogram-xs-2-layer.toml
 create mode 100644 scripts/training/train-color-palette.py

diff --git a/configs/finetune-color-palette-piercus.toml b/configs/finetune-color-palette-piercus.toml
deleted file mode 100644
index 332c953c7..000000000
--- a/configs/finetune-color-palette-piercus.toml
+++ /dev/null
@@ -1,78 +0,0 @@
-script = "finetune-ldm-color-palette.py" # not used for now
-[wandb]
-mode = "online" # "online", "offline", "disabled"
-entity = "piercus"
-project = "color-palette"
-name = "base-l4"
-tags = ["remote", "l4"]
-
-[models]
-unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", train = false, gpu_index = 0}
-text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", train = false, gpu_index = 0}
-lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", train = false, gpu_index = 0}
-color_palette_encoder = {train = true, gpu_index = 0}
-
-[adapters]
-color_palette_adapter = {}
-
-[latent_diffusion]
-unconditional_sampling_probability = 0.1
-offset_noise = 0.2
-
-[color_palette]
-max_colors = 8
-feedforward_dim = 20
-num_layers = 2
-num_attention_heads = 2
-embedding_dim = 10
-
-[training]
-duration = "5:epoch"
-seed = 0
-gpu_index = 0
-batch_size = 10
-gradient_accumulation = "1:step"
-clip_grad_norm = 1.0
-# clip_grad_value = 1.0
-evaluation_interval = "250:step"
-evaluation_seed = 1
-num_workers = 8
-
-[optimizer]
-optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
-learning_rate = 5e-4
-betas = [0.9, 0.999]
-eps = 1e-8
-weight_decay = 1e-2
-
-[scheduler]
-scheduler_type = "ConstantLR"
-update_interval = "1:step"
-warmup = "250:step"
-
-[dropout]
-dropout_probability = 0.2
-use_gyro_dropout = false
-
-[dataset]
-hf_repo = "1aurent/unsplash-lite-palette"
-revision = "main"
-resize_image_max_size = 512
-caption_key = "ai_description"
-split = "train"
-
-[checkpointing]
-save_interval = "1:epoch"
-use_wandb = true
-
-[test_color_palette]
-num_inference_steps = 30
-use_short_prompts = false
-prompts = [
-    {"text" = "a cute cat", "color_palette" = [[0,0,255]], seed=1},
-    {"text" = "a cute cat", "color_palette" = [[255,0,0]], seed=1},
-    {"text" = "a cute cat", "color_palette" = [[0,0,255], [255,255,255], [255,0,0]]},
-    {"text" = "a cute cat", "color_palette" = [[255,0,0], [255,255,255], [0,0,255]]}
-]
-num_palette_sample = 8
-condition_scale = 7.5
diff --git a/configs/finetune-color-palette.toml b/configs/finetune-color-palette.toml
deleted file mode 100644
index f06329610..000000000
--- a/configs/finetune-color-palette.toml
+++ /dev/null
@@ -1,75 +0,0 @@
-script = "finetune-ldm-color-palette.py" # not used for now
-
-[wandb]
-mode = "offline" # "online", "offline", "disabled"
-entity = "acme"
-project = "color-palette-training"
-
-[models]
-unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", train = false}
-text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", train = false}
-lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", train = false}
-color_palette_encoder = {train = true}
-
-[latent_diffusion]
-unconditional_sampling_probability = 0.1
-offset_noise = 0.1
-
-[color_palette]
-max_colors = 8
-feedforward_dim = 512
-feedforward_dim = 512
-num_attention_heads = 6
-num_layers = 3
-
-[training]
-duration = "1000:epoch"
-seed = 0
-gpu_index = 0
-batch_size = 2
-gradient_accumulation = "4:step"
-clip_grad_norm = 1.0
-# clip_grad_value = 1.0
-evaluation_interval = "5:epoch"
-evaluation_seed = 1
-
-[optimizer]
-optimizer = "Prodigy" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
-learning_rate = 1
-betas = [0.9, 0.999]
-eps = 1e-8
-weight_decay = 1e-2
-
-[scheduler]
-scheduler_type = "ConstantLR"
-update_interval = "1:step"
-warmup = "500:step"
-
-
-[dropout]
-dropout_probability = 0.0
-use_gyro_dropout = false
-
-[dataset]
-hf_repo = "1aurent/unsplash-lite-palette"
-revision = "main"
-local_folder = "data/1aurent/unsplash-lite-palette"
-resize_image_max_size = 512
-caption_key = "ai_description"
-split = "train"
-
-[checkpointing]
-# save_folder = "/path/to/ckpts"
-save_interval = "1:epoch"
-
-[test_color_palette]
-num_inference_steps = 30
-use_short_prompts = false
-prompts = [
-    {"text" = "a cute cat", "color_palette" = [[0,0,255]]},
-    {"text" = "a cute cat", "color_palette" = [[255,0,0]]},
-    {"text" = "a cute cat", "color_palette" = [[0,0,255], [255,255,255], [255,0,0]]},
-    {"text" = "a cute cat", "color_palette" = [[255,0,0], [255,255,255], [0,0,255]]}
-]
-num_palette_sample = 8
-condition_scale = 7.5
\ No newline at end of file
diff --git a/configs/finetune-histogram-piercus.toml b/configs/finetune-histogram-piercus.toml
deleted file mode 100644
index 5a2141487..000000000
--- a/configs/finetune-histogram-piercus.toml
+++ /dev/null
@@ -1,76 +0,0 @@
-script = "finetune-ldm-color-palette.py" # not used for now
-[wandb]
-mode = "online" # "online", "offline", "disabled"
-entity = "piercus"
-project = "histogram"
-name="frozen-histo-autoencode"
-
-[models]
-unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", gpu_index=0, train = false}
-text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", gpu_index=1, train = false}
-lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", gpu_index=1, train = false}
-histogram_auto_encoder = {gpu_index=0, train = false, checkpoint = "tmp/step1900.safetensors"}
-
-[latent_diffusion]
-unconditional_sampling_probability = 0.1
-offset_noise = 0.2
-
-[histogram_auto_encoder]
-latent_dim = 8
-resnet_sizes = [4, 4, 4, 4, 4, 4]
-n_down_samples = 5
-color_bits = 6
-
-[training]
-duration = "1:epoch"
-seed = 0
-gpu_index = 0
-batch_size = 1
-gradient_accumulation = "1:step"
-clip_grad_norm = 1.0
-# clip_grad_value = 1.0
-evaluation_interval = "1000:step"
-evaluation_seed = 1
-color_loss_weight = 0.2
-
-[optimizer]
-optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
-learning_rate = 1e-5
-betas = [0.9, 0.999]
-eps = 1e-8
-weight_decay = 1e-2
-
-[scheduler]
-scheduler_type = "ConstantLR"
-update_interval = "1:step"
-warmup = "100:step"
-
-[dropout]
-dropout_probability = 0.0
-use_gyro_dropout = false
-
-[dataset]
-hf_repo = "1aurent/unsplash-lite-palette"
-revision = "main"
-resize_image_max_size = 512
-caption_key = "ai_description"
-split = "train"
-#random_crop = false
-
-[checkpointing]
-#save_folder = "ckpts"
-save_interval = "10000:step"
-
-[test_histogram]
-num_inference_steps = 30
-use_short_prompts = false
-histogram_db_indexes = [0, 1, 2, 3]
-prompts = [
-    "A Bustling City Street",
-    "A cute cat",
-    "An oil painting",
-    "A photography of a beautiful woman",
-    "A pair of shoes",
-    "A group of working people"
-]
-condition_scale = 1.0 # deactivate cause negative sampling not well defined
diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-4_16-32-b8-lr5e-3-kl_div.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-4_16-32-b8-lr5e-3-kl_div.toml
deleted file mode 100644
index 546502361..000000000
--- a/configs/histogram-auto-encoder/train-histogram-autoencoder-4_16-32-b8-lr5e-3-kl_div.toml
+++ /dev/null
@@ -1,64 +0,0 @@
-script = "finetune-ldm-color-palette.py" # not used for now
-[wandb]
-mode = "online" # "online", "offline", "disabled"
-entity = "piercus"
-project = "histo-autoencoder"
-name = "4b-4_8-cube2-b8-lr5e-3-logits"
-tags = ["autoencoder", "4bits"]
-
-[histogram_auto_encoder]
-latent_dim = 4
-resnet_sizes = [32, 32, 32, 32]
-n_down_samples = 3
-color_bits = 4
-num_groups = 1
-loss = "kl_div"
-
-[models]
-histogram_auto_encoder = {train = true}
-
-[training]
-duration = "30:epoch"
-seed = 0
-gpu_index = 1
-batch_size = 8
-gradient_accumulation = "1:step"
-clip_grad_norm = 1.0
-# clip_grad_value = 1.0
-evaluation_interval = "50:step"
-evaluation_seed = 1
-num_workers = 4
-
-[optimizer]
-optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
-learning_rate = 5e-3
-betas = [0.9, 0.999]
-eps = 1e-8
-weight_decay = 1e-2
-
-[scheduler]
-scheduler_type = "ConstantLR"
-update_interval = "1:step"
-warmup = "50:step"
-
-[dropout]
-dropout_probability = 0.2
-use_gyro_dropout = false
-
-[dataset]
-hf_repo = "1aurent/unsplash-lite-palette"
-revision = "main"
-resize_image_max_size = 512
-caption_key = "ai_description"
-split = "train[10:]"
-
-[eval_dataset]
-hf_repo = "1aurent/unsplash-lite-palette"
-revision = "main"
-resize_image_max_size = 512
-caption_key = "ai_description"
-split = "train[:10]"
-
-[checkpointing]
-save_interval = "1000:step"
-use_wandb = true
diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-5b-2_4-8x4-b8-lr5e-3-kl_div.toml b/configs/histogram-auto-encoder/train-histogram-autoencoder-5b-2_4-8x4-b8-lr5e-3-kl_div.toml
deleted file mode 100644
index 5da0d7fd7..000000000
--- a/configs/histogram-auto-encoder/train-histogram-autoencoder-5b-2_4-8x4-b8-lr5e-3-kl_div.toml
+++ /dev/null
@@ -1,64 +0,0 @@
-script = "finetune-ldm-color-palette.py" # not used for now
-[wandb]
-mode = "online" # "online", "offline", "disabled"
-entity = "piercus"
-project = "histo-autoencoder"
-name = "5b-4_8-cube2-b8-lr5e-3-logits"
-tags = ["autoencoder", "6bits", "64x4-emb"]
-
-[histogram_auto_encoder]
-latent_dim = 4
-resnet_sizes = [4, 4, 8, 8, 16]
-n_down_samples = 4
-color_bits = 6
-num_groups = 1
-loss = "kl_div"
-
-[models]
-histogram_auto_encoder = {train = true}
-
-[training]
-duration = "30:epoch"
-seed = 0
-gpu_index = 1
-batch_size = 8
-gradient_accumulation = "1:step"
-clip_grad_norm = 1.0
-# clip_grad_value = 1.0
-evaluation_interval = "50:step"
-evaluation_seed = 1
-num_workers = 4
-
-[optimizer]
-optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
-learning_rate = 5e-3
-betas = [0.9, 0.999]
-eps = 1e-8
-weight_decay = 1e-2
-
-[scheduler]
-scheduler_type = "ConstantLR"
-update_interval = "1:step"
-warmup = "50:step"
-
-[dropout]
-dropout_probability = 0.2
-use_gyro_dropout = false
-
-[dataset]
-hf_repo = "1aurent/unsplash-lite-palette"
-revision = "main"
-resize_image_max_size = 512
-caption_key = "ai_description"
-split = "train[10:]"
-
-[eval_dataset]
-hf_repo = "1aurent/unsplash-lite-palette"
-revision = "main"
-resize_image_max_size = 512
-caption_key = "ai_description"
-split = "train[:10]"
-
-[checkpointing]
-save_interval = "1000:step"
-use_wandb = true
diff --git a/configs/scheduled-local/finetune-color-palette-mlp-transforms.toml b/configs/local/experiments/finetune-color-palette-mlp.toml
similarity index 71%
rename from configs/scheduled-local/finetune-color-palette-mlp-transforms.toml
rename to configs/local/experiments/finetune-color-palette-mlp.toml
index 633e9e028..7f84594f4 100644
--- a/configs/scheduled-local/finetune-color-palette-mlp-transforms.toml
+++ b/configs/local/experiments/finetune-color-palette-mlp.toml
@@ -3,14 +3,14 @@ script = "finetune-ldm-color-palette.py" # not used for now
 mode = "online" # "online", "offline", "disabled"
 entity = "piercus"
 project = "color-palette"
-name = "mlp-weighted-hue"
-tags = ["local", "reload-ckpt", "grid-eval", "weighted"]
+name = "weighted-palette"
+tags = ["reload-ckpt", "weighted-palette"]
 
 [models]
-unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", gpu_index=1, train = false}
-text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", gpu_index=0, train = false}
-lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", gpu_index=0, train = false}
-color_palette_encoder = {gpu_index=1, train = true}
+unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", train = false}
"tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", train = false} +lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", train = false} +color_palette_encoder = {train = true} [adapters] color_palette_adapter = {checkpoint = "./tmp/converted-mlp-step4998.safetensors"} @@ -25,17 +25,19 @@ feedforward_dim = 20 num_layers = 2 mode = 'mlp' embedding_dim = 10 -weighted_palette = true +weighted_palette = false [training] -duration = "2:epoch" +duration = "5:epoch" seed = 0 -batch_size = 1 +batch_size = 10 gradient_accumulation = "1:step" clip_grad_norm = 1.0 # clip_grad_value = 1.0 evaluation_interval = "1000:step" evaluation_seed = 1 +num_workers = 8 +use_color_loss = false [optimizer] optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" @@ -58,29 +60,31 @@ hf_repo = "1aurent/unsplash-lite-palette" revision = "main" resize_image_max_size = 512 caption_key = "ai_description" -split = "train[4:]" -color_jitter = {hue=0.5} +split = "train[200:]" + +[checkpointing] +save_interval = "5000:step" +use_wandb = true [eval_dataset] hf_repo = "1aurent/unsplash-lite-palette" revision = "main" resize_image_max_size = 512 caption_key = "ai_description" -split = "train[0:4]" - -[checkpointing] -save_interval = "10000:step" -use_wandb = true +split = "train[:20]" [evaluation] -batch_size = 3 +batch_size = 4 num_inference_steps = 30 +color_bits = 4 use_short_prompts = false -db_indexes = [0, 1, 2, 3] +db_indexes = [0, 1]#, 2, 3, 4, 5, 6, 7, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] prompts = [ "A Bustling City Street", "A cute cat", - "An oil painting", - "A photography of a beautiful woman" + #"An oil painting", + "A photography of a beautiful woman", + # "A pair of shoes", + # "A group of working people" ] condition_scale = 7.5 diff --git a/configs/histogram-auto-encoder/train-histogram-autoencoder-ckpt-reload.toml b/configs/local/pending/train-histogram-autoencoder-ckpt-reload.toml similarity index 93% rename from configs/histogram-auto-encoder/train-histogram-autoencoder-ckpt-reload.toml rename to configs/local/pending/train-histogram-autoencoder-ckpt-reload.toml index 821823203..8793bfbd4 100644 --- a/configs/histogram-auto-encoder/train-histogram-autoencoder-ckpt-reload.toml +++ b/configs/local/pending/train-histogram-autoencoder-ckpt-reload.toml @@ -15,7 +15,7 @@ num_groups = 4 loss = "kl_div" [models] -histogram_auto_encoder = {train = true, checkpoint = "tmp/step1900.safetensors"} +histogram_auto_encoder = {train = true, checkpoint = "tmp/ckpt-reload-step6000.safetensors"} [training] duration = "10:epoch" diff --git a/configs/local/scheduled/finetune-histogram-jiont-training.toml b/configs/local/scheduled/finetune-histogram-jiont-training.toml new file mode 100644 index 000000000..e69de29bb diff --git a/configs/scheduled-remote/finetune-histogram-jiont-training.toml b/configs/remote/pending/finetune-histogram-jiont-training.toml similarity index 100% rename from configs/scheduled-remote/finetune-histogram-jiont-training.toml rename to configs/remote/pending/finetune-histogram-jiont-training.toml diff --git a/configs/scheduled-local/finetune-color-palette-mlp-noweight.toml b/configs/remote/scheduled/finetune-color-palette-mlp-unweighted.toml similarity index 60% rename from configs/scheduled-local/finetune-color-palette-mlp-noweight.toml rename to configs/remote/scheduled/finetune-color-palette-mlp-unweighted.toml index 0b61d88ee..f9c138386 100644 --- a/configs/scheduled-local/finetune-color-palette-mlp-noweight.toml +++ 
+++ b/configs/remote/scheduled/finetune-color-palette-mlp-unweighted.toml
@@ -1,19 +1,19 @@
 script = "finetune-ldm-color-palette.py" # not used for now
 
 [wandb]
-mode = "online" # "online", "offline", "disabled"
+mode = "offline" # "online", "offline", "disabled"
 entity = "piercus"
 project = "color-palette"
-name = "mlp-not-weighted"
-tags = ["local", "reload-ckpt", "grid-eval", "not-weighted"]
+name = "unweighted-palette"
+tags = ["reload-ckpt", "unweighted-palette"]
 
 [models]
-unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", gpu_index=1, train = false}
-text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", gpu_index=0, train = false}
-lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", gpu_index=0, train = false}
-color_palette_encoder = {gpu_index=1, train = true}
+unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", train = false}
+text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", train = false}
+lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", train = false}
+color_palette_encoder = {train = true}
 
 [adapters]
-color_palette_adapter = {checkpoint = "./tmp/mlp-step4998.safetensors"}
+color_palette_adapter = {checkpoint = "./tmp/converted-mlp-step4998.safetensors"}
 
 [latent_diffusion]
 unconditional_sampling_probability = 0.1
@@ -28,14 +28,16 @@ embedding_dim = 10
 weighted_palette = false
 
 [training]
-duration = "2000:step"
+duration = "5:epoch"
 seed = 0
-batch_size = 1
+batch_size = 10
 gradient_accumulation = "1:step"
 clip_grad_norm = 1.0
 # clip_grad_value = 1.0
 evaluation_interval = "1000:step"
 evaluation_seed = 1
+num_workers = 8
+use_color_loss = false
 
 [optimizer]
 optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
@@ -58,21 +60,31 @@ hf_repo = "1aurent/unsplash-lite-palette"
 revision = "main"
 resize_image_max_size = 512
 caption_key = "ai_description"
-split = "train"
+split = "train[200:]"
 
 [checkpointing]
-save_interval = "2000:step"
+save_interval = "5000:step"
 use_wandb = true
 
+[eval_dataset]
+hf_repo = "1aurent/unsplash-lite-palette"
+revision = "main"
+resize_image_max_size = 512
+caption_key = "ai_description"
+split = "train[:20]"
+
 [evaluation]
-batch_size = 3
+batch_size = 4
 num_inference_steps = 30
+color_bits = 4
 use_short_prompts = false
-histogram_db_indexes = [0, 1, 2, 3]
+db_indexes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
 prompts = [
     "A Bustling City Street",
     "A cute cat",
-    "An oil painting",
-    "A photography of a beautiful woman"
+    #"An oil painting",
+    "A photography of a beautiful woman",
+    # "A pair of shoes",
+    # "A group of working people"
 ]
 condition_scale = 7.5
diff --git a/configs/scheduled-local/finetune-color-palette-mlp.toml b/configs/remote/scheduled/finetune-color-palette-mlp-weighted.toml
similarity index 72%
rename from configs/scheduled-local/finetune-color-palette-mlp.toml
rename to configs/remote/scheduled/finetune-color-palette-mlp-weighted.toml
index 98bbf0eba..6b9ef1151 100644
--- a/configs/scheduled-local/finetune-color-palette-mlp.toml
+++ b/configs/remote/scheduled/finetune-color-palette-mlp-weighted.toml
@@ -3,14 +3,14 @@ script = "finetune-ldm-color-palette.py" # not used for now
 mode = "offline" # "online", "offline", "disabled"
 entity = "piercus"
 project = "color-palette"
-name = "mlp-weighted"
-tags = ["local", "reload-ckpt", "grid-eval", "weighted"]
+name = "weighted-palette"
"weighted-palette"] [models] -unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", gpu_index=1, train = false} -text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", gpu_index=0, train = false} -lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", gpu_index=0, train = false} -color_palette_encoder = {gpu_index=1, train = true} +unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", train = false} +text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", train = false} +lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", train = false} +color_palette_encoder = {train = true} [adapters] color_palette_adapter = {checkpoint = "./tmp/converted-mlp-step4998.safetensors"} @@ -28,14 +28,16 @@ embedding_dim = 10 weighted_palette = true [training] -duration = "2000:step" +duration = "5:epoch" seed = 0 -batch_size = 1 +batch_size = 10 gradient_accumulation = "1:step" clip_grad_norm = 1.0 # clip_grad_value = 1.0 evaluation_interval = "1000:step" evaluation_seed = 1 +num_workers = 8 +use_color_loss = false [optimizer] optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" @@ -58,10 +60,10 @@ hf_repo = "1aurent/unsplash-lite-palette" revision = "main" resize_image_max_size = 512 caption_key = "ai_description" -split = "train" +split = "train[200:]" [checkpointing] -save_interval = "10000:step" +save_interval = "5000:step" use_wandb = true [eval_dataset] @@ -69,17 +71,20 @@ hf_repo = "1aurent/unsplash-lite-palette" revision = "main" resize_image_max_size = 512 caption_key = "ai_description" -split = "train[0:4]" +split = "train[:20]" [evaluation] -batch_size = 3 +batch_size = 4 num_inference_steps = 30 +color_bits = 4 use_short_prompts = false -db_indexes = [0, 1, 2, 3] +db_indexes = [0, 1, 2, 3, 4, 5, 6, 7, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] prompts = [ "A Bustling City Street", "A cute cat", - "An oil painting", - "A photography of a beautiful woman" + #"An oil painting", + "A photography of a beautiful woman", + # "A pair of shoes", + # "A group of working people" ] condition_scale = 7.5 diff --git a/configs/scheduled-local-histogram/finetune-histogram-xs-2-layer.toml b/configs/scheduled-local-histogram/finetune-histogram-xs-2-layer.toml deleted file mode 100644 index e169ac199..000000000 --- a/configs/scheduled-local-histogram/finetune-histogram-xs-2-layer.toml +++ /dev/null @@ -1,88 +0,0 @@ -script = "finetune-ldm-color-palette.py" # not used for now -[wandb] -mode = "offline" # "online", "offline", "disabled" -entity = "piercus" -project = "histogram" -name="local-histo-palette-eval" -tags = ["l4", "local"] - -[models] -unet = {gpu_index = 0, checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", train = false} -text_encoder = {gpu_index = 1,checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", train = false} -lda = {gpu_index = 1, checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", train = false} -histogram_auto_encoder = {gpu_index = 0, train = true, learning_rate = 1e-2} -histogram_projection = {gpu_index = 0, train = true, learning_rate = 1e-2} - -[latent_diffusion] -unconditional_sampling_probability = 0.1 -offset_noise = 0.1 - -[histogram_auto_encoder] -latent_dim = 8 -resnet_sizes = [4, 4, 4, 4, 4, 4] -n_down_samples = 5 -color_bits = 6 - -[training] -duration = "100:step" -seed = 0 -batch_size = 1 -gradient_accumulation = "1:step" 
-clip_grad_norm = 1.0
-# clip_grad_value = 1.0
-evaluation_interval = "2000:step"
-evaluation_seed = 1
-color_loss_weight = 0.1
-
-[optimizer]
-optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
-learning_rate = 1e-4
-betas = [0.9, 0.999]
-eps = 1e-8
-weight_decay = 1e-2
-
-[adapter]
-embedding_dim = 10
-num_tokens = 1
-num_layers = 2
-feedforward_dim = 20
-
-[scheduler]
-scheduler_type = "ConstantLR"
-update_interval = "1:step"
-warmup = "500:step"
-
-[dropout]
-dropout_probability = 0.1
-use_gyro_dropout = false
-
-[dataset]
-hf_repo = "1aurent/unsplash-lite-palette"
-revision = "main"
-resize_image_max_size = 512
-caption_key = "ai_description"
-split = "train[200:]"
-#random_crop = false
-
-[eval_dataset]
-hf_repo = "1aurent/unsplash-lite-palette"
-revision = "main"
-resize_image_max_size = 512
-caption_key = "ai_description"
-split = "train[0:200]"
-
-[checkpointing]
-#save_folder = "ckpts"
-save_interval = "1:epoch"
-use_wandb = true
-
-[evaluation]
-num_inference_steps = 30
-db_indexes = [15]
-batch_size = 4
-prompts = [
-# "A Bustling City Street",
-    "A cute cat",
-# "A photography of a beautiful woman"
-]
-condition_scale = 7.5
\ No newline at end of file
diff --git a/scripts/training/train-color-palette.py b/scripts/training/train-color-palette.py
new file mode 100644
index 000000000..c0d81b573
--- /dev/null
+++ b/scripts/training/train-color-palette.py
@@ -0,0 +1,63 @@
+"""Dispatch a training run to the matching trainer based on the config file path."""
+from refiners.training_utils.trainers.color_palette import ColorPaletteLatentDiffusionConfig, ColorPaletteLatentDiffusionTrainer
+from refiners.training_utils.trainers.histogram import HistogramLatentDiffusionConfig, HistogramLatentDiffusionTrainer
+from refiners.training_utils.trainers.histogram_auto_encoder import HistogramAutoEncoderTrainer, TrainHistogramAutoEncoderConfig
+from refiners.training_utils.trainers.latent_diffusion import FinetuneLatentDiffusionBaseConfig
+
+
+def adapt_config(config_path: str, config: FinetuneLatentDiffusionBaseConfig) -> FinetuneLatentDiffusionBaseConfig:
+    if 'local' in config_path:
+        config.training.batch_size = 1
+        config.training.num_workers = min(4, config.training.num_workers)
+        config.wandb.tags = config.wandb.tags + ['local']
+
+        if 'unet' in config.models:
+            config.models['unet'].gpu_index = 1
+
+        if 'text_encoder' in config.models:
+            config.models['text_encoder'].gpu_index = 0
+
+        if 'lda' in config.models:
+            config.models['lda'].gpu_index = 0
+
+        if 'color_palette_encoder' in config.models:
+            config.models['color_palette_encoder'].gpu_index = 1
+
+        return config
+    else:
+        config.wandb.tags = config.wandb.tags + ['remote']
+        return config
+
+def train_histogram_auto_encoder(config_path: str) -> None:
+    config = TrainHistogramAutoEncoderConfig.load_from_toml(
+        toml_path=config_path,
+    )
+    trainer = HistogramAutoEncoderTrainer(config=adapt_config(config_path, config))
+    trainer.train()
+
+def train_color_palette(config_path: str) -> None:
+    config = ColorPaletteLatentDiffusionConfig.load_from_toml(
+        toml_path=config_path,
+    )
+    trainer = ColorPaletteLatentDiffusionTrainer(config=adapt_config(config_path, config))
+    trainer.train()
+
+def train_histogram(config_path: str) -> None:
+    config = HistogramLatentDiffusionConfig.load_from_toml(
+        toml_path=config_path,
+    )
+    trainer = HistogramLatentDiffusionTrainer(config=adapt_config(config_path, config))
+    trainer.train()
+
+if __name__ == "__main__":
+    import sys
+    config_path = sys.argv[1]
+
+    if 'histogram-autoencoder' in config_path:
+        train_histogram_auto_encoder(config_path)
+    elif 'color-palette' in config_path:
+        train_color_palette(config_path)
+    elif 'histogram' in config_path:
+        train_histogram(config_path)
+    else:
+        raise ValueError(f"Invalid config path: {config_path}")
\ No newline at end of file
diff --git a/src/refiners/fluxion/adapters/color_palette.py b/src/refiners/fluxion/adapters/color_palette.py
index f5911253a..cba4cf63d 100644
--- a/src/refiners/fluxion/adapters/color_palette.py
+++ b/src/refiners/fluxion/adapters/color_palette.py
@@ -220,7 +220,6 @@ def __call__(self, image: Image.Image, size: int | None = None) -> ColorPalette:
         pixels = image_np.reshape(-1, 3)
         return self.from_pixels(pixels, size)
     def from_pixels(self, pixels: np.ndarray, size: int | None = None) -> ColorPalette:
-        print("pixels.shape", pixels.shape)
         kmeans = KMeans(n_clusters=size).fit(pixels) # type: ignore
         counts = np.unique(kmeans.labels_, return_counts=True)[1] # type: ignore
         palette : ColorPalette = []
diff --git a/src/refiners/fluxion/adapters/histogram.py b/src/refiners/fluxion/adapters/histogram.py
index 949881de6..b2df3ba9c 100644
--- a/src/refiners/fluxion/adapters/histogram.py
+++ b/src/refiners/fluxion/adapters/histogram.py
@@ -175,10 +175,20 @@ def metrics_log(self, log: Tensor, y: Tensor) -> dict[str, Tensor]:
             "chi_square": self.chi_square(x, y),
             "intersection": self.intersection(x, y),
             "hellinger": self.hellinger(x, y),
-            "kl_div": self.kl_div(log, y)
-            # "emd": self.emd(x, y)
+            "kl_div": self.kl_div(log, y),
+            "emd": self.emd(x, y)
+        }
+    def metrics(self, x: Tensor, y: Tensor, eps: float = 1e-7) -> dict[str, Tensor]:
+
+        return {
+            "mse": self.mse(x, y),
+            "correlation": self.correlation(x, y),
+            "chi_square": self.chi_square(x, y),
+            "intersection": self.intersection(x, y),
+            "hellinger": self.hellinger(x, y),
+            "kl_div": self.kl_div((x+eps).log(), y),
+            "emd": self.emd(x, y)
         }
-
 class HistogramExtractor(fl.Chain):
     def __init__(
         self,
diff --git a/src/refiners/training_utils/metrics/color_palette.py b/src/refiners/training_utils/metrics/color_palette.py
index e93cf4dab..973805131 100644
--- a/src/refiners/training_utils/metrics/color_palette.py
+++ b/src/refiners/training_utils/metrics/color_palette.py
@@ -42,6 +42,7 @@ def __init__(
     def collate_fn(cls: Type[PromptType], batch: Sequence["AbstractColorPrompt"]) -> PromptType:
         opts : dict[str, CollatableProps] = {}
         for key in cls._list_keys:
+
            opts[key] : list[Any] = []
 
            for item in batch:
@@ -50,6 +51,7 @@ def collate_fn(cls: Type[PromptType], batch: Sequence["AbstractColorPrompt"]) ->
                for prop in getattr(item, key):
                    opts[key].append(prop)
        for key in cls._tensor_keys:
+
            lst : list[Tensor] = []
            for item in batch:
                if not hasattr(item, key):
@@ -58,8 +60,9 @@ def collate_fn(cls: Type[PromptType], batch: Sequence["AbstractColorPrompt"]) ->
                if not isinstance(tensor, Tensor):
                    raise ValueError(f"Key {key}, {tensor} should be a tensor")
                lst.append(tensor)
-           opts[key] = cat(lst) 
+           opts[key] = cat(lst)
+
 
        return cls(**opts)
 
    @classmethod
@@ -126,7 +129,7 @@ class BatchHistogramPrompt(AbstractColorPrompt):
     }
 
 class BatchHistogramResults(AbstractColorResults[AbstractColorPrompt]):
-    _list_keys: List[str] = ["source_palettes", "source_prompts", "source_images", "db_indexes"]
+    _list_keys: List[str] = ["source_palettes", "source_prompts", "source_images", "db_indexes", "result_palettes"]
     _tensor_keys: dict[str, tuple[int, ...]] = {
         "source_histograms": (64, 64, 64),
         "text_embeddings": (77, 768),
diff --git a/src/refiners/training_utils/trainers/abstract_color_trainer.py b/src/refiners/training_utils/trainers/abstract_color_trainer.py
index 0665bd27a..6ab3deb24 100644
--- a/src/refiners/training_utils/trainers/abstract_color_trainer.py
+++ b/src/refiners/training_utils/trainers/abstract_color_trainer.py
@@ -11,6 +11,7 @@
 from torch.utils.data import DataLoader
 from refiners.fluxion.adapters.histogram import (
+    HistogramDistance,
     HistogramExtractor
 )
 from PIL import Image
@@ -42,6 +43,7 @@ class ColorTrainerEvaluationConfig(TestDiffusionBaseConfig):
     db_indexes: list[int]
     batch_size: int = 1
+    color_bits: int = 8
 
 class ColorTrainerConfig(FinetuneLatentDiffusionBaseConfig):
     evaluation: ColorTrainerEvaluationConfig
@@ -61,12 +63,15 @@ def __init__(self, db_indexes: list[int], hf_dataset: ColorPaletteDataset, sourc
         self.hf_dataset = hf_dataset
         self.source_prompts = source_prompts
         self.text_encoder = text_encoder
-        self.text_embeddings : list[Tensor] = [self.text_encoder(prompt) for prompt in source_prompts]
+
+        txt_emb = [self.text_encoder(prompt).cpu() for prompt in source_prompts]
+        self.text_embeddings : list[Tensor] = txt_emb
 
     def __len__(self):
         return len(self.db_indexes) * len(self.source_prompts)
 
     def __getitem__(self, index: int) -> PromptType:
+
         db_index = self.db_indexes[index // len(self.source_prompts)]
         source_prompt = self.source_prompts[index % len(self.source_prompts)]
         batch = self.hf_dataset[db_index]
@@ -143,6 +148,10 @@ def draw_palette(self, palette: ColorPalette, width: int, height: int) -> Image.
 
         return palette_img
 
+    @cached_property
+    def histogram_distance(self) -> HistogramDistance:
+        return HistogramDistance(color_bits=self.config.evaluation.color_bits)
+
     @scoped_seed(5)
     def compute_batch_evaluation(self, batch: PromptType, same_seed: bool = True) -> ResultType:
         batch_size = len(batch.source_prompts)
@@ -240,7 +249,7 @@ def batch_metrics(
 
     @cached_property
     def histogram_extractor(self) -> HistogramExtractor:
-        return HistogramExtractor(color_bits=self.config.histogram_auto_encoder.color_bits)
+        return HistogramExtractor(color_bits=self.config.evaluation.color_bits)
 
 
     @cached_property
diff --git a/src/refiners/training_utils/trainers/color_palette.py b/src/refiners/training_utils/trainers/color_palette.py
index a26115b89..0bf15f1e2 100644
--- a/src/refiners/training_utils/trainers/color_palette.py
+++ b/src/refiners/training_utils/trainers/color_palette.py
@@ -6,7 +6,7 @@ from PIL import Image
 from pydantic import BaseModel
 from refiners.training_utils.trainers.abstract_color_trainer import AbstractColorTrainer, ColorTrainerEvaluationConfig
-from refiners.training_utils.metrics.color_palette import BatchColorPalettePrompt
+from refiners.training_utils.metrics.color_palette import BatchHistogramPrompt
 from refiners.training_utils.trainers.abstract_color_trainer import GridEvalDataset
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.training_utils.datasets.color_palette import ColorPaletteDataset
@@ -20,12 +20,14 @@
     SD1UNet
 )
 from refiners.training_utils.callback import Callback
-from refiners.training_utils.metrics.color_palette import BatchColorPalettePrompt, BatchColorPaletteResults, ImageAndPalette, batch_image_palette_metrics
+from refiners.training_utils.metrics.color_palette import BatchHistogramPrompt, BatchHistogramResults, ImageAndPalette, batch_image_palette_metrics
+from refiners.training_utils.trainers.histogram import GridEvalHistogramDataset
 from refiners.training_utils.trainers.latent_diffusion import (
     FinetuneLatentDiffusionBaseConfig,
 )
 from refiners.training_utils.datasets.color_palette import ColorDatasetConfig, ColorPalette, TextEmbeddingColorPaletteLatentsBatch
 from refiners.training_utils.callback import GradientNormLayerLogging
+from refiners.training_utils.wandb import WandbLoggable
 
 class ColorPaletteConfig(BaseModel):
     feedforward_dim: int = 3072
@@ -53,22 +55,21 @@ class ColorPaletteLatentDiffusionConfig(FinetuneLatentDiffusionBaseConfig):
     dataset: ColorDatasetConfig
     eval_dataset: ColorDatasetConfig
 
-class GridEvalPaletteDataset(GridEvalDataset[BatchColorPalettePrompt]):
-    __prompt_type__ = BatchColorPalettePrompt
+class GridEvalPaletteDataset(GridEvalDataset[BatchHistogramPrompt]):
+    __prompt_type__ = BatchHistogramPrompt
 
     def __init__(self, db_indexes: list[int], hf_dataset: ColorPaletteDataset, source_prompts: list[str], text_encoder: CLIPTextEncoderL, color_palette_extractor: ColorPaletteExtractor):
         super().__init__(db_indexes, hf_dataset, source_prompts, text_encoder)
         self.color_palette_extractor = color_palette_extractor
 
     def process_item(self, items: TextEmbeddingColorPaletteLatentsBatch) -> dict[str, Any]:
-
         if len(items) != 1:
             raise ValueError("The items must have length 1.")
-
+        
         source_palettes = [self.color_palette_extractor(item.image, size=len(item.color_palette)) for item in items]
         return {
             "source_palettes": source_palettes
         }
 
-class ColorPaletteLatentDiffusionTrainer(AbstractColorTrainer[BatchColorPalettePrompt, BatchColorPaletteResults, ColorPaletteLatentDiffusionConfig]):
+class ColorPaletteLatentDiffusionTrainer(AbstractColorTrainer[BatchHistogramPrompt, BatchHistogramResults, ColorPaletteLatentDiffusionConfig]):
     @cached_property
     def color_palette_encoder(self) -> ColorPaletteEncoder:
         assert (
@@ -127,16 +128,17 @@ def load_models(self) -> dict[str, fl.Module]:
         }
 
     @cached_property
-    def grid_eval_dataset(self) -> GridEvalDataset[BatchColorPalettePrompt]:
-        return GridEvalPaletteDataset(
+    def grid_eval_dataset(self) -> GridEvalDataset[BatchHistogramPrompt]:
+        return GridEvalHistogramDataset(
             db_indexes=self.config.evaluation.db_indexes,
             hf_dataset=self.eval_dataset,
             source_prompts=self.config.evaluation.prompts,
             text_encoder=self.text_encoder,
+            histogram_extractor=self.histogram_extractor,
             color_palette_extractor=self.color_palette_extractor
         )
-
-    # def eval_dataset(self) -> list[BatchColorPalettePrompt]:
+
+    # def eval_dataset(self) -> list[BatchHistogramPrompt]:
    #     dataset = self.dataset
    #     indices = self.config.evaluation.db_indexes
    #     items = [dataset[i][0] for i in indices]
@@ -145,12 +147,12 @@ def grid_eval_dataset(self) -> GridEvalDataset[BatchColorPalettePrompt]:
    #     images = [item.image for item in items]
    #     eval_indices = list(zip(indices, palette, images))
 
-    #     evaluations : list[BatchColorPalettePrompt] = []
+    #     evaluations : list[BatchHistogramPrompt] = []
    #     prompts_list = [(prompt, self.text_encoder(prompt)) for prompt in self.config.evaluation.prompts]
    #     for (prompt, prompt_embedding) in prompts_list:
    #         for db_index, palette, image in eval_indices:
-    #             batch_prompt = BatchColorPalettePrompt(
+    #             batch_prompt = BatchHistogramPrompt(
    #                 source_prompts= [prompt],
    #                 db_indexes= [db_index],
    #                 source_palettes= [palette],
@@ -162,26 +164,28 @@ def grid_eval_dataset(self) -> GridEvalDataset[BatchColorPalettePrompt]:
    #     print(f"Eval dataset size: {len(evaluations)}")
    #     return evaluations
 
-    def build_results(self, batch: BatchColorPalettePrompt, result_images: Tensor) -> BatchColorPaletteResults:
+    def build_results(self, batch: BatchHistogramPrompt, result_images: Tensor) -> BatchHistogramResults:
-
         return BatchHistogramResults(
             source_prompts=batch.source_prompts,
             db_indexes=batch.db_indexes,
+            source_histograms=batch.source_histograms,
             source_palettes=batch.source_palettes,
+            result_histograms=self.histogram_extractor(result_images),
             result_images=result_images,
             source_images=batch.source_images,
             result_palettes=[self.color_palette_extractor(image, size=len(batch.source_palettes[i])) for i, image in enumerate(tensor_to_images(result_images))],
             text_embeddings=batch.text_embeddings
         )
 
-    def collate_results(self, batch: list[BatchColorPaletteResults]) -> BatchColorPaletteResults:
-        return BatchColorPaletteResults.collate_fn(batch)
+    def collate_results(self, batch: list[BatchHistogramResults]) -> BatchHistogramResults:
+        return BatchHistogramResults.collate_fn(batch)
 
-    def empty(self) -> BatchColorPaletteResults:
-        return BatchColorPaletteResults.empty()
+    def empty(self) -> BatchHistogramResults:
+        return BatchHistogramResults.empty()
 
-    def collate_prompts(self, batch: list[BatchColorPalettePrompt]) -> BatchColorPalettePrompt:
-        return BatchColorPalettePrompt.collate_fn(batch)
+    def collate_prompts(self, batch: list[BatchHistogramPrompt]) -> BatchHistogramPrompt:
+        return BatchHistogramPrompt.collate_fn(batch)
 
 
     def compute_loss(self, batch: TextEmbeddingColorPaletteLatentsBatch) -> Tensor:
@@ -208,14 +212,14 @@ def compute_loss(self, batch: TextEmbeddingColorPaletteLatentsBatch) -> Tensor:
 
         return loss
 
-    def eval_set_adapter_values(self, batch: BatchColorPalettePrompt) -> None:
+    def eval_set_adapter_values(self, batch: BatchHistogramPrompt) -> None:
         self.color_palette_adapter.set_color_palette_embedding(
             self.color_palette_encoder.compute_color_palette_embedding(
                 batch.source_palettes
             )
         )
 
-    def draw_cover_image(self, batch: BatchColorPaletteResults) -> Image.Image:
+    def draw_cover_image(self, batch: BatchHistogramResults) -> Image.Image:
         (batch_size, _, height, width) = batch.result_images.shape
 
         palette_img_size = width // self.config.color_palette.max_colors
@@ -246,12 +250,13 @@ def draw_cover_image(self, batch: BatchColorPaletteResults) -> Image.Image:
 
     #     return distance
 
-    def batch_metrics(self, results: BatchColorPaletteResults, prefix: str = "palette-img") -> None:
+    def batch_metrics(self, results: BatchHistogramResults, prefix: str = "palette-img") -> None:
         palettes : list[list[Color]] = []
         for p in results.source_palettes:
             palettes.append([cluster[0] for cluster in p])
 
         images = tensor_to_images(results.result_images)
+
         batch_image_palette_metrics(
             self.log,
             [
@@ -260,7 +265,15 @@ def batch_metrics(self, results: BatchColorPaletteResults, prefix: str = "palett
             ],
             prefix
         )
-
+
+        histo_metrics = self.histogram_distance.metrics(results.result_histograms, results.source_histograms.to(results.result_histograms.device))
+
+        log_dict : dict[str, WandbLoggable] = {}
+        for (key, value) in histo_metrics.items():
+            log_dict[f"eval_histo/{key}"] = value
+
+        self.log(log_dict)
+
 class LoadColorPalette(Callback[ColorPaletteLatentDiffusionTrainer]):
     def on_train_begin(self, trainer: ColorPaletteLatentDiffusionTrainer) -> None:
         adapter = trainer.color_palette_adapter
diff --git a/src/refiners/training_utils/trainers/histogram.py b/src/refiners/training_utils/trainers/histogram.py
index 198cb6d14..8cbd56b66 100644
--- a/src/refiners/training_utils/trainers/histogram.py
+++ b/src/refiners/training_utils/trainers/histogram.py
@@ -67,12 +67,10 @@ def __init__(self,
         source_prompts: list[str],
         text_encoder: CLIPTextEncoderL,
         histogram_extractor: HistogramExtractor,
-        histogram_auto_encoder: HistogramAutoEncoder,
         color_palette_extractor: ColorPaletteExtractor
     ) -> None:
         super().__init__(db_indexes=db_indexes, hf_dataset=hf_dataset, source_prompts=source_prompts, text_encoder=text_encoder)
         self.histogram_extractor = histogram_extractor
-        self.histogram_auto_encoder = histogram_auto_encoder
         self.color_palette_extractor = color_palette_extractor
 
     def process_item(self, items: TextEmbeddingColorPaletteLatentsBatch) -> dict[str, Any]:
@@ -80,12 +78,10 @@ def process_item(self, items: TextEmbeddingColorPaletteLatentsBatch) -> dict[str
             raise ValueError("The items must have length 1.")
 
         histograms = self.histogram_extractor.images_to_histograms([item.image for item in items])
-        histogram_embeddings = self.histogram_auto_encoder.encode(histograms).reshape(histograms.shape[0], 1, -1)
         source_palettes = [self.color_palette_extractor(item.image, size=len(item.color_palette)) for item in items]
 
         return {
             "source_palettes": source_palettes,
-            "source_histogram_embeddings": histogram_embeddings,
             "source_histograms": histograms
         }
@@ -102,7 +98,6 @@ def grid_eval_dataset(self) -> GridEvalDataset[BatchHistogramPrompt]:
             source_prompts=self.config.evaluation.prompts,
             text_encoder=self.text_encoder,
             histogram_extractor=self.histogram_extractor,
-            histogram_auto_encoder=self.histogram_auto_encoder,
             color_palette_extractor=self.color_palette_extractor
         )
@@ -129,7 +124,6 @@ def batch_metrics(self, results: BatchHistogramResults, prefix: str = "histogram
         )})
 
-
         [red, green, blue] = self.color_loss.image_vs_histo(
             results.result_images.to(device=self.device),
             results.source_histograms.to(device=self.device),
@@ -184,33 +178,7 @@ def histogram_projection(self) -> HistogramProjection:
     @cached_property
     def color_loss(self) -> ColorLoss:
         return ColorLoss()
-
-    # @cached_property
-    # def eval_dataset(self) -> list[BatchHistogramPrompt]:
-    #     dataset = self.dataset
-    #     indices = self.config.evaluation.db_indexes
-    #     items = [dataset[i][0] for i in indices]
-    #     palette = [item.color_palette for item in items]
-    #     images = [item.image for item in items]
-    #     histograms = self.histogram_extractor.images_to_histograms(images, device = self.device, dtype = self.dtype)
-    #     histogram_embeddings = self.histogram_auto_encoder.encode(histograms).reshape(histograms.shape[0], 1, -1)
-    #     eval_indices = list(zip(indices, histograms.split(1), histogram_embeddings.split(1), palette, images)) # type: ignore
-    #     evaluations : list[BatchHistogramPrompt] = []
-    #     prompts_list = [(prompt, self.text_encoder(prompt)) for prompt in self.config.evaluation.prompts]
-    #     for (prompt, prompt_embedding) in prompts_list:
-    #         for db_index, histogram, histogram_embedding, palette, image in eval_indices: # type: ignore
-    #             batch_histogram_prompt = BatchHistogramPrompt(
-    #                 source_histogram_embeddings= histogram_embedding, # type: ignore
-    #                 source_histograms= histogram, # type: ignore
-    #                 source_prompts= [prompt],
-    #                 db_indexes= [db_index],
-    #                 source_palettes= [palette],
-    #                 text_embeddings= prompt_embedding, # type: ignore
-    #                 source_images= [image]
-    #             )
-    #             evaluations.append(batch_histogram_prompt)
-    #     return evaluations
-
+
     def collate_results(self, batch: list[BatchHistogramResults]) -> BatchHistogramResults:
         return BatchHistogramResults.collate_fn(batch)
@@ -300,15 +268,7 @@ def eval_set_adapter_values(self, batch: BatchHistogramPrompt) -> None:
         )
         cfg_histogram_embedding2 = self.histogram_projection(cfg_histogram_embedding)
         self.histogram_adapter.set_histogram_embedding(cfg_histogram_embedding2)
-
-        # TO FIX: batch eval not working here
-
-        # uncode = self.unconditionnal_text_embedding
-        # unconditionnal_text_emb = uncode.repeat(batch_size, 1, 1)
-        # cfg_clip_text_embedding = cat([batch.text_embeddings, unconditionnal_text_emb], dim=0)
-        #unconditionnal_histo_embedding = self.histogram_auto_encoder.unconditionnal_embedding_like(batch.source_histogram_embeddings)
-        #cfg_histogram_embedding = cat([batch.source_histogram_embeddings, unconditionnal_histo_embedding], dim=0)
-
+
     def draw_curves(self, res_histo: list[float], src_histo: list[float], color: str, width: int, height: int) -> Image.Image:
         histo_img = Image.new(mode="RGB", size=(width, height))
 
@@ -334,9 +294,7 @@ def draw_curves(self, res_histo: list[float], src_histo: list[float], color: str
     def draw_cover_image(self, batch: BatchHistogramResults) -> Image.Image:
         (batch_size, channels, height, width) = batch.result_images.shape
 
-        # for i in range(batch_size):
-        #     logger.info(f"draw_cover_image eval_images/{batch.source_prompts[i]}_{batch.db_indexes[i]} : img hash : {hash_tensor(batch.images[i])}, txt_hash: {hash_tensor(batch.text_embeddings[i])}, histo_hash: {hash_tensor(batch.source_histogram_embeddings[i])}")
-
+
         vertical_image = batch.result_images.permute(0,2,3,1).reshape(1, height*batch_size, width, channels).permute(0,3,1,2)
 
         results_histograms = batch.result_histograms