To install the dependencies:
conda create --name modnas python=3.9
conda activate modnas
pip install -r requirements.txt
├── predictors
│ ├── hat
│ ├── help
│ ├── nb201
│ ├── ofa
├── hypernetworks
│ ├── models
│ ├── pretrain_hpns
├── scripts
├── optimizers
│ ├── help
│ ├──mgd
│ ├── mixop
│ ├── sampler
│ ├── optim_factory.py
├── plotting
├── search_spaces
│ ├── hat
│ ├── MobileNetV3
│ ├── nb201
The predictors
folder contains the meta predictors for different search spaces
The hypernetworks
folder contains the architectures of our hypernetworks for different search spaces
The scripts
folder contains the scripts to batch different jobs
The optimizers
folder contains the different one-shot and black box optimizers for architecture search
The plotting
folder contains the scripts used for radar plots
The search_spaces
folder contains the definition of the search spaces search spaces nasbench201, mobilenetv3, hardware aware transformers
The predictor_data_utils
and hypernetwork_data_utils
folder contains the pretrained predictors and hypernetworks respectively
The baselines
folder contains the scripts to run the synetune baselines for different search spaces.
CIFAR10
and CIFAR100
datasets will be automatically downloaded
Download the imagenet-1k
from here and update the path to the dataset in the training script. The dataset Imagenet16-120
Follow the instructions here to download the binary files for the different machine translation datasets.
python hypernetworks/pretrain_hpns/pretrain_hpns_nb201.py
python hypernetworks/pretrain_hpns/pretrain_hpns_ofa.py
python hypernetworks/pretrain_hpns/pretrain_hpns_hat.py
python predictors/nb201/train/train_predictor.py
python predictors/ofa/train/train_ofa_predictor.py
python predictors/hat/train_latency_predictor.py --task wmt14.en-de
python search_spaces/nb201/search_nb201_mgd.py \
--save mgd-100epochs \
--wandb_name "modnas-nb201-100epochs" \
--optimizer_type "reinmax" \
--arch_weight_decay 0.09 \
--train_portion 0.5 \
--learning_rate 0.025 \
--learning_rate_min 0.001 \
--seed 9001 \
--epochs 100 \
--load_path "predictor_data_utils/nb201/predictor_meta_learned.pth" \
--w_grad_update_method "mean" \
--hpn_grad_update_method "mgd" \
--weight_decay 0.0027
python -m torch.distributed.launch --nproc_per_node=8 --use_env search_spaces/MobileNetV3/search/mobilenet_search_base.py --one_shot_opt reinmax --opt_strategy "simultaneous" --hpn_type meta --use_pretrained_hpn
python search_spaces/hat/train.py --configs=search_spaces/hat/configs/wmt14.en-de/supertransformer/space0.yml
Download the pretrained supernet from HW-GPT-Bench. Download device embeddings and processed dataset from here
python -u search_spaces/gpt/train_llm_configurable_scratch.py -c juwels_owt_sw_s.yaml
The scripts/
folder contains slurm scripts to launch the above search scripts
For NB201 we evaluate the hypervolume during training
For MobileNetV3 and HAT search spaces, obtain the architectures on the Pareto-Front using:
python evaluation/get_archs_ofa.py
python evaluation/get_archs_hat.py --config-file search_spaces/hat/configs/wmt14.en-de/supertransformer/space0.yml --arch transformersuper_wmt_en_de2
To evaluate the archs from MobileNetV3 space:
python search_spaces/MobileNetV3/eval_ofa_net.py --net ofa_mbv3_d234_e346_k357_w1.2
To evaluate the archs from HAT space:
python search_spaces/hat/eval_archs_hat.py
We then use the scoring protocol for BLEU and Sacre-BLEU from HAT to score the evaluations
To make the radar plots we use the file plotting/plot_radar.py
The baselines
folder and the baseline_scripts
folder contains the code to run different baselines on all different search spaces.