Skip to content
This repository has been archived by the owner on May 1, 2024. It is now read-only.

Commit

Permalink
DatasetHparams: added fields subset_start, subset_end, and subset_stride
Browse files: browse the repository at this point in the history
  • Loading branch information
wrongu committed Aug 30, 2022
1 parent dc8c318 commit 9aa3a37
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
11 changes: 10 additions & 1 deletion datasets/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
# LICENSE file in the root directory of this source tree.

import numpy as np

from torch.utils.data import Subset
from datasets import base, cifar10, mnist, imagenet
from foundations.hparams import DatasetHparams
from platforms.platform import get_platform


# Lookup table from dataset name to the corresponding module imported from `datasets`.
registered_datasets = {
    'cifar10': cifar10,
    'mnist': mnist,
    'imagenet': imagenet,
}


Expand All @@ -31,6 +32,14 @@ def get(dataset_hparams: DatasetHparams, train: bool = True):
if train and dataset_hparams.random_labels_fraction is not None:
dataset.randomize_labels(seed=seed, fraction=dataset_hparams.random_labels_fraction)

if dataset_hparams.subset_start is not None or dataset_hparams.subset_stride != 1 or dataset_hparams.subset_end is not None:
if dataset_hparams.subsample_fraction is not None:
raise ValueError("Cannot have both subsample_fraction and subset_[start,end,stride]")
subset_start = 0 if dataset_hparams.subset_start is None else dataset_hparams.subset_start
subset_end = len(dataset) if dataset_hparams.subset_end is None else dataset_hparams.subset_end
subset_stride = 1 if dataset_hparams.subset_stride is None else dataset_hparams.subset_stride
dataset = Subset(dataset, np.arange(subset_start, subset_end, subset_stride))

if train and dataset_hparams.subsample_fraction is not None:
dataset.subsample(seed=seed, fraction=dataset_hparams.subsample_fraction)

Expand Down
6 changes: 6 additions & 0 deletions foundations/hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ class DatasetHparams(Hparams):
do_not_augment: bool = False
transformation_seed: int = None
subsample_fraction: float = None
subset_start: int = None
subset_end: int = None
subset_stride: int = 1
random_labels_fraction: float = None
unsupervised_labels: str = None
blur_factor: int = None
Expand All @@ -140,6 +143,9 @@ class DatasetHparams(Hparams):
_transformation_seed: str = 'The random seed that controls dataset transformations like ' \
'random labels, subsampling, and unsupervised labels.'
_subsample_fraction: str = 'Subsample the training set, retaining the specified fraction: float in (0, 1]'
_subset_start: str = 'If set, use a Subset with indices range(subset_start, subset_end, subset_stride)'
_subset_end: str = 'If set, use a Subset with indices range(subset_start, subset_end, subset_stride)'
_subset_stride: str = 'Stride of subset indices (default 1)'
_random_labels_fraction: str = 'Apply random labels to a fraction of the training set: float in (0, 1]'
_unsupervised_labels: str = 'Replace the standard labels with alternative, unsupervised labels. Example: rotation'
_blur_factor: str = 'Blur the training set by downsampling and then upsampling by this multiple.'
Expand Down

0 comments on commit 9aa3a37

Please sign in to comment.