Skip to content

Commit

Permalink
num_workers in dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Jan 26, 2024
1 parent 8deaeb0 commit ff49c65
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 11 deletions.
1 change: 1 addition & 0 deletions configs/finetune-color-palette-piercus.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ clip_grad_norm = 1.0
# clip_grad_value = 1.0
evaluation_interval = "250:step"
evaluation_seed = 1
num_workers = 8

[optimizer]
optimizer = "AdamW" # "SGD", "Adam", "AdamW", "AdamW8bit", "Lion8bit"
Expand Down
1 change: 1 addition & 0 deletions src/refiners/training_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class TrainingConfig(BaseModel):
gpu_index: int = 0
dtype: str = "float32"
batch_size: int = 1
num_workers: int = 0
gradient_accumulation: TimeValue = {"number": 1, "unit": TimeUnit.STEP}
clip_grad_norm: float | None = None
clip_grad_value: float | None = None
Expand Down
13 changes: 3 additions & 10 deletions src/refiners/training_utils/datasets/color_palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,28 +54,21 @@ def __init__(
)

def __getitem__(self, index: int) -> TextEmbeddingColorPaletteLatentsBatch:
logger.info(f"Getting latents {index}")

item : DatasetItem = self.hf_dataset[index]
logger.info(f"Getting item {index}")

resized_image = self.resize_image(
image=item["image"],
min_size=self.config.resize_image_min_size,
max_size=self.config.resize_image_max_size,
)
logger.info(f"resized_image image {index}")

image = self.process_image(resized_image)

logger.info(f"process_image image {index}")

caption_key = self.config.caption_key or "caption"
caption_key = self.config.caption_key
caption = item[caption_key]

(caption_processed, conditional_flag) = self.process_caption(self.get_caption(index))
logger.info(f"process_caption {index}")

(caption_processed, conditional_flag) = self.process_caption(caption)

return [
ColorPaletteDatasetItem(
color_palette=self.get_color_palette(index),
Expand Down
2 changes: 1 addition & 1 deletion src/refiners/training_utils/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def dataset_length(self) -> int:
def dataloader(self) -> DataLoader[Batch]:
collate_fn = getattr(self.dataset, "collate_fn", None)
return DataLoader(
dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn
dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=self.config.training.num_workers
)

@property
Expand Down

0 comments on commit ff49c65

Please sign in to comment.