From ff49c653a7f697291ae0601160fa8eb621a4ae07 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Fri, 26 Jan 2024 21:41:46 +0100 Subject: [PATCH] num_workers in dataloader --- configs/finetune-color-palette-piercus.toml | 1 + src/refiners/training_utils/config.py | 1 + .../training_utils/datasets/color_palette.py | 13 +++---------- src/refiners/training_utils/trainers/trainer.py | 2 +- 4 files changed, 6 insertions(+), 11 deletions(-) diff --git a/configs/finetune-color-palette-piercus.toml b/configs/finetune-color-palette-piercus.toml index 8f6111bed..816078ee2 100644 --- a/configs/finetune-color-palette-piercus.toml +++ b/configs/finetune-color-palette-piercus.toml @@ -35,6 +35,7 @@ clip_grad_norm = 1.0 # clip_grad_value = 1.0 evaluation_interval = "250:step" evaluation_seed = 1 +num_workers = 8 [optimizer] optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit" diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 664585975..e7732264e 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -55,6 +55,7 @@ class TrainingConfig(BaseModel): gpu_index: int = 0 dtype: str = "float32" batch_size: int = 1 + num_workers: int = 0 gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP} clip_grad_norm: float | None = None clip_grad_value: float | None = None diff --git a/src/refiners/training_utils/datasets/color_palette.py b/src/refiners/training_utils/datasets/color_palette.py index 5f8965192..aacaf1feb 100644 --- a/src/refiners/training_utils/datasets/color_palette.py +++ b/src/refiners/training_utils/datasets/color_palette.py @@ -54,28 +54,21 @@ def __init__( ) def __getitem__(self, index: int) -> TextEmbeddingColorPaletteLatentsBatch: - logger.info(f"Getting latents {index}") item : DatasetItem = self.hf_dataset[index] - logger.info(f"Getting item {index}") resized_image = self.resize_image( image=item["image"], min_size=self.config.resize_image_min_size, max_size=self.config.resize_image_max_size, ) - logger.info(f"resized_image image {index}") image = self.process_image(resized_image) - logger.info(f"process_image image {index}") - - caption_key = self.config.caption_key or "caption" + caption_key = self.config.caption_key caption = item[caption_key] - - (caption_processed, conditional_flag) = self.process_caption(self.get_caption(index)) - logger.info(f"process_caption {index}") - + (caption_processed, conditional_flag) = self.process_caption(caption) + return [ ColorPaletteDatasetItem( color_palette=self.get_color_palette(index), diff --git a/src/refiners/training_utils/trainers/trainer.py b/src/refiners/training_utils/trainers/trainer.py index f29f2f3db..54f0cc81f 100644 --- a/src/refiners/training_utils/trainers/trainer.py +++ b/src/refiners/training_utils/trainers/trainer.py @@ -512,7 +512,7 @@ def dataset_length(self) -> int: def dataloader(self) -> DataLoader[Batch]: collate_fn = getattr(self.dataset, "collate_fn", None) return DataLoader( - dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn + dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=self.config.training.num_workers ) @property