Skip to content

Commit

Permalink
Merge branch 'handle-url-hf-datasets' into fabric
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Jan 12, 2024
2 parents 636b442 + 55624a2 commit 46ddaa4
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
9 changes: 5 additions & 4 deletions configs/finetune-lora.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ entity = "acme"
project = "test-lora-training"

[models]
unet = {checkpoint = "/path/to/stable-diffusion-1-5/unet.safetensors"}
text_encoder = {checkpoint = "/path/to/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
lda = {checkpoint = "/path/to/stable-diffusion-1-5/lda.safetensors"}
unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors"}
text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors"}

[latent_diffusion]
unconditional_sampling_probability = 0.05
Expand Down Expand Up @@ -50,8 +50,9 @@ dropout_probability = 0.2
use_gyro_dropout = false

[dataset]
hf_repo = "acme/images"
hf_repo = "1aurent/unsplash-lite-palette"
revision = "main"
caption_key = "ai_description"

[checkpointing]
# save_folder = "/path/to/ckpts"
Expand Down
2 changes: 1 addition & 1 deletion scripts/training/finetune-ldm-textual-inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
)
self.placeholder_token = self.config.textual_inversion.placeholder_token

def get_caption(self, index: int, caption_key: str = "caption") -> str:
    """Return a caption built from a randomly chosen template.

    Args:
        index: Position of the sample in the dataset. Unused here.
        caption_key: Name of the dataset caption column. Unused here; it is
            accepted so that callers passing ``caption_key=...`` keep working.
            Defaults to ``"caption"`` so one-argument calls remain valid.

    Returns:
        One entry of ``self.templates`` formatted with
        ``self.placeholder_token``.
    """
    # Ignore the dataset caption, if any: use a template instead
    return random.choice(self.templates).format(self.placeholder_token)

Expand Down
3 changes: 2 additions & 1 deletion src/refiners/training_utils/huggingface_datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Generic, Protocol, TypeVar, cast

from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore
from datasets import DownloadManager, VerificationMode, load_dataset as _load_dataset # type: ignore
from pydantic import BaseModel # type: ignore

__all__ = ["load_hf_dataset", "HuggingfaceDataset"]
Expand Down Expand Up @@ -34,3 +34,4 @@ class HuggingfaceDatasetConfig(BaseModel):
use_verification: bool = False
resize_image_min_size: int = 512
resize_image_max_size: int = 576
caption_key: str = "caption"
19 changes: 15 additions & 4 deletions src/refiners/training_utils/latent_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import cached_property
from typing import Any, Callable, TypedDict, TypeVar

from datasets import DownloadManager # type: ignore
from loguru import logger
from PIL import Image
from pydantic import BaseModel
Expand Down Expand Up @@ -58,6 +59,7 @@ class TextEmbeddingLatentsBatch:
class CaptionImage(TypedDict):
    """Typed mapping describing one captioned image sample.

    NOTE(review): all three keys are declared as required here; confirm
    whether rows are expected to carry both ``image`` and ``url``.
    """

    # Text associated with the image.
    caption: str
    # Image as a PIL object.
    image: Image.Image
    # Address of the image; presumably a download location -- TODO confirm.
    url: str


ConfigType = TypeVar("ConfigType", bound=FinetuneLatentDiffusionConfig)
Expand All @@ -72,6 +74,8 @@ def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
self.text_encoder = self.trainer.text_encoder
self.dataset = self.load_huggingface_dataset()
self.process_image = self.build_image_processor()
self.download_manager = DownloadManager()

logger.info(f"Loaded {len(self.dataset)} samples from dataset")

def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
Expand All @@ -98,14 +102,21 @@ def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int =
def process_caption(self, caption: str) -> str:
    """Randomly replace the caption with an empty string.

    Args:
        caption: The caption to keep or drop.

    Returns:
        ``caption`` when a fresh ``random.random()`` draw is strictly greater
        than ``config.latent_diffusion.unconditional_sampling_probability``,
        otherwise ``""``.
    """
    drop_probability = self.config.latent_diffusion.unconditional_sampling_probability
    # Strict comparison: a draw equal to the threshold drops the caption.
    return caption if random.random() > drop_probability else ""

def get_caption(self, index: int, caption_key: str = "caption") -> str:
    """Return the caption stored for one dataset row.

    Args:
        index: Position of the row in ``self.dataset``.
        caption_key: Name of the field holding the caption. Defaults to
            ``"caption"``, the key that was read before the field name
            became configurable, so one-argument calls behave as before.

    Returns:
        The value stored under ``caption_key`` for that row.

    Raises:
        KeyError: If the row is a plain mapping without ``caption_key``
            (other row types may raise differently -- TODO confirm).
    """
    return self.dataset[index][caption_key]  # type: ignore

def get_image(self, index: int) -> Image.Image:
    """Return the image for one dataset row, downloading it if necessary.

    A row that contains an ``"image"`` entry is returned from directly.
    Otherwise, a row that contains a ``"url"`` entry is downloaded through
    ``self.download_manager`` and the resulting file is opened with PIL.

    Args:
        index: Position of the row in ``self.dataset``.

    Returns:
        The row's image.

    Raises:
        RuntimeError: If the row has neither an ``"image"`` nor a ``"url"``
            entry.
    """
    # Look the row up once instead of re-indexing the dataset for every
    # membership test and read.
    item = self.dataset[index]

    if "image" in item:
        return item["image"]

    if "url" in item:
        url: str = item["url"]
        filename: str = self.download_manager.download(url)  # type: ignore
        # NOTE(review): Image.open is lazy and keeps the file handle open
        # until the pixel data is loaded -- confirm downstream code loads it.
        return Image.open(filename)

    raise RuntimeError(f"Dataset item at index [{index}] does not contain 'image' or 'url'")

def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
caption = self.get_caption(index=index)
caption = self.get_caption(index=index, caption_key=self.config.dataset.caption_key)
image = self.get_image(index=index)
resized_image = self.resize_image(
image=image,
Expand Down

0 comments on commit 46ddaa4

Please sign in to comment.