From d9b124230bfb4b1492c537a1059466d23a00a156 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Fri, 12 Jan 2024 16:00:07 +0100 Subject: [PATCH 1/2] url dataset --- configs/finetune-lora.toml | 3 ++- .../training_utils/huggingface_datasets.py | 3 ++- .../training_utils/latent_diffusion.py | 18 ++++++++++++++---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/configs/finetune-lora.toml b/configs/finetune-lora.toml index 4d4409cbd..8617c50ec 100644 --- a/configs/finetune-lora.toml +++ b/configs/finetune-lora.toml @@ -50,8 +50,9 @@ dropout_probability = 0.2 use_gyro_dropout = false [dataset] -hf_repo = "acme/images" +hf_repo = "1aurent/unsplash-lite-palette" revision = "main" +caption_key = "ai_description" [checkpointing] # save_folder = "/path/to/ckpts" diff --git a/src/refiners/training_utils/huggingface_datasets.py b/src/refiners/training_utils/huggingface_datasets.py index 3e73ad8c8..928269e62 100644 --- a/src/refiners/training_utils/huggingface_datasets.py +++ b/src/refiners/training_utils/huggingface_datasets.py @@ -1,6 +1,6 @@ from typing import Any, Generic, Protocol, TypeVar, cast -from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore +from datasets import DownloadManager, VerificationMode, load_dataset as _load_dataset # type: ignore from pydantic import BaseModel # type: ignore __all__ = ["load_hf_dataset", "HuggingfaceDataset"] @@ -34,3 +34,4 @@ class HuggingfaceDatasetConfig(BaseModel): use_verification: bool = False resize_image_min_size: int = 512 resize_image_max_size: int = 576 + caption_key: str = "caption" diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index f4f8ccf0a..566f91439 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -3,6 +3,7 @@ from functools import cached_property from typing import Any, Callable, TypedDict, TypeVar +from datasets import DownloadManager from 
loguru import logger from PIL import Image from pydantic import BaseModel @@ -72,6 +73,8 @@ def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None: self.text_encoder = self.trainer.text_encoder self.dataset = self.load_huggingface_dataset() self.process_image = self.build_image_processor() + self.download_manager = DownloadManager() + logger.info(f"Loaded {len(self.dataset)} samples from dataset") def build_image_processor(self) -> Callable[[Image.Image], Image.Image]: @@ -98,14 +101,21 @@ def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = def process_caption(self, caption: str) -> str: return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else "" - def get_caption(self, index: int) -> str: - return self.dataset[index]["caption"] + def get_caption(self, index: int, caption_key: str) -> str: + return self.dataset[index][caption_key] def get_image(self, index: int) -> Image.Image: - return self.dataset[index]["image"] + if "image" in self.dataset[index]: + return self.dataset[index]["image"] + elif "url" in self.dataset[index]: + url = self.dataset[index]["url"] + filename = self.download_manager.download(url) + return Image.open(filename) + else: + raise RuntimeError(f"Dataset item at index [{index}] does not contain 'image' or 'url'") def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch: - caption = self.get_caption(index=index) + caption = self.get_caption(index=index, caption_key=self.config.dataset.caption_key) image = self.get_image(index=index) resized_image = self.resize_image( image=image, From 55624a2097098345787ec4fb885b06ad87afb697 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Fri, 12 Jan 2024 17:56:18 +0100 Subject: [PATCH 2/2] fix: pyright --- configs/finetune-lora.toml | 6 +++--- scripts/training/finetune-ldm-textual-inversion.py | 2 +- src/refiners/training_utils/latent_diffusion.py | 9 +++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff 
--git a/configs/finetune-lora.toml b/configs/finetune-lora.toml index 8617c50ec..ba786ac28 100644 --- a/configs/finetune-lora.toml +++ b/configs/finetune-lora.toml @@ -4,9 +4,9 @@ entity = "acme" project = "test-lora-training" [models] -unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"} -text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"} -lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"} +unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors"} +text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"} +lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors"} [latent_diffusion] unconditional_sampling_probability = 0.05 diff --git a/scripts/training/finetune-ldm-textual-inversion.py b/scripts/training/finetune-ldm-textual-inversion.py index 18142b3a0..9627e6001 100644 --- a/scripts/training/finetune-ldm-textual-inversion.py +++ b/scripts/training/finetune-ldm-textual-inversion.py @@ -84,7 +84,7 @@ def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None: ) self.placeholder_token = self.config.textual_inversion.placeholder_token - def get_caption(self, index: int) -> str: + def get_caption(self, index: int, caption_key: str) -> str: # Ignore the dataset caption, if any: use a template instead return random.choice(self.templates).format(self.placeholder_token) diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 566f91439..397c680ed 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -3,7 +3,7 @@ from functools import cached_property from typing import Any, Callable, TypedDict, TypeVar -from datasets import DownloadManager +from datasets import DownloadManager # type: ignore from loguru import logger from PIL import Image from pydantic import BaseModel @@ -59,6 +59,7 @@ class 
TextEmbeddingLatentsBatch: class CaptionImage(TypedDict): caption: str image: Image.Image + url: str ConfigType = TypeVar("ConfigType", bound=FinetuneLatentDiffusionConfig) @@ -102,14 +103,14 @@ def process_caption(self, caption: str) -> str: return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else "" def get_caption(self, index: int, caption_key: str) -> str: - return self.dataset[index][caption_key] + return self.dataset[index][caption_key] # type: ignore def get_image(self, index: int) -> Image.Image: if "image" in self.dataset[index]: return self.dataset[index]["image"] elif "url" in self.dataset[index]: - url = self.dataset[index]["url"] - filename = self.download_manager.download(url) + url: str = self.dataset[index]["url"] + filename: str = self.download_manager.download(url) # type: ignore return Image.open(filename) else: raise RuntimeError(f"Dataset item at index [{index}] does not contain 'image' or 'url'")