Skip to content

Commit

Permalink
sometimes, cluster count is less than size
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Feb 16, 2024
1 parent bbd3436 commit 066637f
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 7 deletions.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ num_groups = 4
loss = "kl_div"

[models]
histogram_auto_encoder = {train = true, checkpoint = "tmp/ckpt-reload-step6000.safetensors"}
histogram_auto_encoder = {train=true, checkpoint = "tmp/ckpt-reload-step6000.safetensors"}

[training]
duration = "10:epoch"
duration = "20:epoch"
seed = 0
gpu_index = 1
batch_size = 8
Expand Down
4 changes: 2 additions & 2 deletions scripts/training/scheduler-local.bash
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/bin/bash

# Path to the directory containing the config files
config_dir="./configs/histogram-auto-encoder"
config_dir="./configs/remote/scheduled"
prefix=""
script="./scripts/training/train-histogram-autoencoder.py"
script="./scripts/training/train-color-palette.py"
# Log file path
log_file="./tmp/schedule-log.txt"

Expand Down
10 changes: 8 additions & 2 deletions src/refiners/fluxion/adapters/color_palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ def __call__(self, image: Image.Image, size: int | None = None) -> ColorPalette:
image_np = np.array(image)
pixels = image_np.reshape(-1, 3)
return self.from_pixels(pixels, size)
def from_pixels(self, pixels: np.ndarray, size: int | None = None) -> ColorPalette:
def from_pixels(self, pixels: np.ndarray, size: int | None = None, eps : float = 1e-7) -> ColorPalette:
kmeans = KMeans(n_clusters=size).fit(pixels) # type: ignore
counts = np.unique(kmeans.labels_, return_counts=True)[1] # type: ignore
palette : ColorPalette = []
total = pixels.shape[0]
for i in range(0, size):
for i in range(0, len(counts)):
center_float : tuple[float, float, float] = kmeans.cluster_centers_[i] # type: ignore
center : Color = tuple(center_float.astype(int)) # type: ignore
count = float(counts[i].item())
Expand All @@ -233,6 +233,12 @@ def from_pixels(self, pixels: np.ndarray, size: int | None = None) -> ColorPalet
count / total if self.weighted_palette else 1.0 / size
)
palette.append(color_cluster)

if len(counts) < size:
for _ in range(size - len(counts)):
pal : ColorPaletteCluster = ((0, 0, 0), eps if self.weighted_palette else 1.0 / size)
palette.append(pal)

sorted_palette = sorted(palette, key=lambda x: x[1], reverse=True)
return sorted_palette

Expand Down
1 change: 0 additions & 1 deletion src/refiners/fluxion/adapters/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def emd(self, x: Tensor, y: Tensor) -> Tensor:
s_y = sample_points(x)
s_x = sample_points(y)
emd = emd_loss(s_x, s_y)
print(f"EMD: {emd.mean()}, shape: {emd.shape}")
return emd.mean()

def correlation(self, x: Tensor, y: Tensor) -> Tensor:
Expand Down

0 comments on commit 066637f

Please sign in to comment.