From 066637f8129c8e1eb4d3e16335ba13045eb487bd Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Fri, 16 Feb 2024 09:31:25 +0100 Subject: [PATCH] sometimes, cluster count is less than size --- .../scheduled/finetune-histogram-jiont-training.toml | 0 .../train-histogram-autoencoder-ckpt-reload.toml | 4 ++-- scripts/training/scheduler-local.bash | 4 ++-- src/refiners/fluxion/adapters/color_palette.py | 10 ++++++++-- src/refiners/fluxion/adapters/histogram.py | 1 - 5 files changed, 12 insertions(+), 7 deletions(-) delete mode 100644 configs/local/scheduled/finetune-histogram-jiont-training.toml diff --git a/configs/local/scheduled/finetune-histogram-jiont-training.toml b/configs/local/scheduled/finetune-histogram-jiont-training.toml deleted file mode 100644 index e69de29bb..000000000 diff --git a/configs/local/scheduled/train-histogram-autoencoder-ckpt-reload.toml b/configs/local/scheduled/train-histogram-autoencoder-ckpt-reload.toml index 8793bfbd4..f7169c7d4 100644 --- a/configs/local/scheduled/train-histogram-autoencoder-ckpt-reload.toml +++ b/configs/local/scheduled/train-histogram-autoencoder-ckpt-reload.toml @@ -15,10 +15,10 @@ num_groups = 4 loss = "kl_div" [models] -histogram_auto_encoder = {train = true, checkpoint = "tmp/ckpt-reload-step6000.safetensors"} +histogram_auto_encoder = {train=true, checkpoint = "tmp/ckpt-reload-step6000.safetensors"} [training] -duration = "10:epoch" +duration = "20:epoch" seed = 0 gpu_index = 1 batch_size = 8 diff --git a/scripts/training/scheduler-local.bash b/scripts/training/scheduler-local.bash index 05773bf52..2a8ee33f7 100644 --- a/scripts/training/scheduler-local.bash +++ b/scripts/training/scheduler-local.bash @@ -1,9 +1,9 @@ #!/bin/bash # Path to the directory containing the config files -config_dir="./configs/histogram-auto-encoder" +config_dir="./configs/remote/scheduled" prefix="" -script="./scripts/training/train-histogram-autoencoder.py" +script="./scripts/training/train-color-palette.py" # Log file path log_file="./tmp/schedule-log.txt" diff --git a/src/refiners/fluxion/adapters/color_palette.py b/src/refiners/fluxion/adapters/color_palette.py index cba4cf63d..ebd344397 100644 --- a/src/refiners/fluxion/adapters/color_palette.py +++ b/src/refiners/fluxion/adapters/color_palette.py @@ -219,12 +219,12 @@ def __call__(self, image: Image.Image, size: int | None = None) -> ColorPalette: image_np = np.array(image) pixels = image_np.reshape(-1, 3) return self.from_pixels(pixels, size) - def from_pixels(self, pixels: np.ndarray, size: int | None = None) -> ColorPalette: + def from_pixels(self, pixels: np.ndarray, size: int | None = None, eps : float = 1e-7) -> ColorPalette: kmeans = KMeans(n_clusters=size).fit(pixels) # type: ignore counts = np.unique(kmeans.labels_, return_counts=True)[1] # type: ignore palette : ColorPalette = [] total = pixels.shape[0] - for i in range(0, size): + for i in range(0, len(counts)): center_float : tuple[float, float, float] = kmeans.cluster_centers_[i] # type: ignore center : Color = tuple(center_float.astype(int)) # type: ignore count = float(counts[i].item()) @@ -233,6 +233,12 @@ def from_pixels(self, pixels: np.ndarray, size: int | None = None) -> ColorPalet count / total if self.weighted_palette else 1.0 / size ) palette.append(color_cluster) + + if len(counts) < size: + for _ in range(size - len(counts)): + pal : ColorPaletteCluster = ((0, 0, 0), eps if self.weighted_palette else 1.0 / size) + palette.append(pal) + sorted_palette = sorted(palette, key=lambda x: x[1], reverse=True) return sorted_palette diff --git a/src/refiners/fluxion/adapters/histogram.py b/src/refiners/fluxion/adapters/histogram.py index b2df3ba9c..fe54fe092 100644 --- a/src/refiners/fluxion/adapters/histogram.py +++ b/src/refiners/fluxion/adapters/histogram.py @@ -137,7 +137,6 @@ def emd(self, x: Tensor, y: Tensor) -> Tensor: s_y = sample_points(x) s_x = sample_points(y) emd = emd_loss(s_x, s_y) - print(f"EMD: {emd.mean()}, shape: {emd.shape}") return emd.mean() def correlation(self, x: Tensor, y: Tensor) -> Tensor: