OFA-Compress is a unified framework for OFA compression. It provides OFA model finetuning, distillation, and inference capabilities in the Huggingface version, and aims to promote lightweight large models.
- ofa: provides the OFA model implemented on Huggingface Transformers.
- data_utils: provides an `OFADataset` that subclasses `torch.utils.data.Dataset` to process data into samples and labels, and implements task-specific dataset classes (e.g., caption_dataset.py, refcoco_dataset.py, snli_ve_dataset.py); see the sketch after this list.
- scripts: provides task-specific shell scripts for evaluation, finetuning, and distillation.
- train: provides functions to execute models.
- textbrewer: a PyTorch-based knowledge distillation toolkit for natural language processing. Check https://github.com/airaria/TextBrewer
- generate: the sequence generator implemented on the Fairseq codebase.
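To illustrate how data_utils is organized, here is a purely hypothetical sketch of a task-specific dataset built on `OFADataset`; the import path, attributes, and helper methods below are assumptions for illustration, not the repository's actual code.

```python
# Hypothetical sketch only: the module path, attributes, and helpers are assumed.
from data_utils.ofa_dataset import OFADataset  # assumed import path


class ToyCaptionDataset(OFADataset):
    """Illustrative task-specific dataset in the style of caption_dataset.py."""

    def __getitem__(self, index):
        row = self.rows[index]            # assumed raw record storage
        sample = self.build_sample(row)   # assumed image/text preprocessing helper
        label = row["caption"]            # assumed label field
        return sample, label
```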
- python 3.6
- pytorch 1.8
- torchvision 0.9.1
- transformers 4.16.2
- datasets 1.17.0
- pillow 8.3.2
We welcome contributions to our project. Feel free to contact us or send us issues/PRs!
Below we present the results of OFA models on cross-modal tasks.
| Task | Image Captioning | Visual Entailment | Referring Expression Comprehension | | |
| --- | --- | --- | --- | --- | --- |
| Dataset | COCO | SNLI-VE | RefCOCO | RefCOCO+ | RefCOCOg |
| Split | Karpathy test (CE) | val / test | val / test-a / test-b | val / test-a / test-b | val-u / test-u |
| Metric | CIDEr | Acc. | Acc. | Acc. | Acc. |
| OFA-Tiny | 119.0 | 85.3 / 85.2 | 80.20 / 84.07 / 75.00 | 68.22 / 75.13 / 57.66 | 72.02 / 69.74 |
| OFA-Compress-Tiny | 120.0 | 87.0 / 86.9 | 81.29 / 85.18 / 75.29 | 71.28 / 77.08 / 61.13 | 72.08 / 71.67 |
```bash
git clone https://github.com/OFA-Sys/OFA-Compress
pip install -r requirements.txt
```
See datasets.md and checkpoints.md.
Below we provide methods for finetuning, distillation and inference on different downstream tasks.
To use OFA-Compress, first download the datasets and pretrained checkpoints from the OFA repository (see checkpoints.md and datasets.md).
Since these checkpoints are trained in the Fairseq framework, we provide a script, convert_ofa_original_ckpt_to_huggingface.py, to convert an original checkpoint to the Huggingface format:
```bash
python convert_ofa_original_ckpt_to_huggingface.py --pt_model /xxx/ofa-refcoco-large/refcoco_large_best.pt --hf_model_dir /xxx/ofa-refcoco-large/
```
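After conversion, the Huggingface-format checkpoint directory can be loaded with the `OFAModel` class provided in this repository. A minimal sketch (the path below is a placeholder matching the command above):

```python
from ofa.modeling_ofa import OFAModel

# Load the converted checkpoint from the directory passed as --hf_model_dir above
# (placeholder path).
model = OFAModel.from_pretrained("/xxx/ofa-refcoco-large/")
```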
To finetune OFA, set `${init_method}` to `load_pretrain`; the framework will then load the pretrained checkpoint from the path you set in `${load}`.
We provide the finetuning scripts as follows:
```bash
cd scripts/finetune
bash caption_finetune.sh  # Image captioning task. For RefCOCO and SNLI-VE, use refcoco_finetune.sh and snlive_finetune.sh
```
To start task-specific distillation, you need to provide the finetuned teacher model and the untrained or pretrained student model in model_paths.py.
Then, set up the distillation configuration, such as the knowledge distillation loss `${kd_loss_type}`, the layer matches `${intermediate_matches}`, etc.
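For reference, `${kd_loss_type}` is a TextBrewer KD loss name (e.g., `ce` or `mse`), and `${intermediate_matches}` is a list of layer-matching specifications. Below is a sketch assuming TextBrewer's standard format; the layer indices, weights, and feature names are illustrative only, and the values actually used in our scripts (which may use OFA-specific encoder/decoder features) can differ.

```python
# Illustrative values only; see the distillation scripts for the real settings.
kd_loss_type = "ce"  # cross-entropy between teacher and student logits

intermediate_matches = [
    # Match student hidden states to teacher hidden states with an MSE loss.
    {"layer_T": 0,  "layer_S": 0, "feature": "hidden", "loss": "hidden_mse", "weight": 1},
    {"layer_T": 11, "layer_S": 3, "feature": "hidden", "loss": "hidden_mse", "weight": 1},
    # Attention maps can be matched as well.
    {"layer_T": 11, "layer_S": 3, "feature": "attention", "loss": "attention_mse", "weight": 1},
]
```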
We provide the distillation scripts as follows:
```bash
cd scripts/distill
bash caption_distill.sh  # Image captioning task. For RefCOCO and SNLI-VE, use refcoco_distill.sh and snlive_distill.sh
```
```python
from ofa.modeling_ofa import OFAModel
from criterions import AdjustLabelSmoothedCrossEntropyCriterion
from ofa_distill import OFADistiller
from ofa_distill import OFADistillationConfig
from textbrewer import TrainingConfig

# Teacher and student models; both return attentions and hidden states,
# which the distiller needs for intermediate-layer matching.
output_dict = {
    "output_attentions": True,
    "output_hidden_states": True
}
model_T = OFAModel.from_pretrained("ofa-caption-large-stage1", **output_dict)
model_S = OFAModel.from_pretrained("ofa-tiny", **output_dict)


# The adaptor maps each batch and the model outputs to the fields expected by the distiller.
def simple_adaptor(batch, model_outputs):
    outputs = {}
    criterion = AdjustLabelSmoothedCrossEntropyCriterion()
    loss, sample_size, logging_output = criterion(model_outputs, batch)
    outputs["losses"] = loss / logging_output['sample_size']
    outputs["sample_size"] = logging_output['sample_size']
    outputs["target"] = batch["target"]
    if "constraint_masks" in batch:
        outputs["constraint_masks"] = batch["constraint_masks"]
    for k1, k2 in zip(["encoder_attentions", "decoder_attentions",
                       "encoder_hidden_states", "decoder_hidden_states",
                       "encoder_last_hidden_state", "logits",
                       "cross_attentions"],
                      ["encoder_attention", "decoder_attention",
                       "encoder_hidden", "decoder_hidden",
                       "encoder_last", "logits",
                       "cross_attention"]):
        if k1 in model_outputs:
            outputs[k2] = model_outputs[k1]
    return outputs


# Training configuration
train_config = TrainingConfig()

# Distillation configuration
# Matching different layers of the student and the teacher
distill_config = OFADistillationConfig(
    text_preprocessor=args.tokenizer,
    temperature=args.temperature,
    temperature_scheduler=args.temperature_scheduler,
    hard_label_weight=args.hard_label_weight,
    hard_label_weight_scheduler=args.hard_label_weight_scheduler,
    kd_loss_type=args.kd_loss_type,
    kd_loss_weight=args.kd_loss_weight,
    kd_loss_weight_scheduler=args.kd_loss_weight_scheduler,
    probability_shift=args.probability_shift,
    intermediate_matches=args.intermediate_matches,
    is_caching_logits=args.is_caching_logits,
    constraint_range=args.constraint_range)

# Build distiller
distiller = OFADistiller(train_config, distill_config, model_T,
                         model_S, simple_adaptor, simple_adaptor)

# Start!
with distiller:
    distiller.train(optimizer,
                    scheduler_class=scheduler_class,
                    scheduler_args=scheduler_args,
                    max_grad_norm=1.0,
                    dataloader=train_loader,
                    num_epochs=10,
                    callback=None)
```
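The snippet above assumes that `args`, `optimizer`, `scheduler_class`, `scheduler_args`, and `train_loader` are defined elsewhere (in our scripts they come from the command-line configuration). One way to construct the training objects, following the usual TextBrewer pattern, is sketched below; the hyperparameters and the dataset are placeholders.

```python
import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

# Placeholder values; tune these for your task.
num_epochs = 10
batch_size = 8
train_dataset = ...  # e.g., one of the task datasets from data_utils

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Only the student's parameters are optimized during distillation.
optimizer = torch.optim.AdamW(model_S.parameters(), lr=1e-4)

# TextBrewer builds the scheduler itself from a class/callable and its kwargs.
num_training_steps = len(train_loader) * num_epochs
scheduler_class = get_linear_schedule_with_warmup
scheduler_args = {
    "num_warmup_steps": int(0.1 * num_training_steps),
    "num_training_steps": num_training_steps,
}
```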
To evaluate your models, first provide the model checkpoint in the `${load}` configuration.
We provide the inference scripts as follows:
```bash
cd scripts/evaluate
bash caption_evaluate.sh  # Image captioning task. For RefCOCO and SNLI-VE, use refcoco_evaluate.sh and snlive_evaluate.sh
```
Feel free to submit GitHub issues or pull requests. Contributions to our project are welcome!
To contact us, never hesitate to send an email to [email protected] or [email protected]!
Please cite our paper if you find it helpful :)
```bibtex
@article{Lu2022KnowledgeDO,
  author  = {Chengqiang Lu and
             Jianwei Zhang and
             Yunfei Chu and
             Zhengyu Chen and
             Jingren Zhou and
             Fei Wu and
             Haiqing Chen and
             Hongxia Yang},
  title   = {Knowledge Distillation of Transformer-based Language Models Revisited},
  journal = {ArXiv},
  volume  = {abs/2206.14366},
  year    = {2022}
}
```

```bibtex
@article{wang2022ofa,
  author  = {Peng Wang and
             An Yang and
             Rui Men and
             Junyang Lin and
             Shuai Bai and
             Zhikang Li and
             Jianxin Ma and
             Chang Zhou and
             Jingren Zhou and
             Hongxia Yang},
  title   = {OFA: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence
             Learning Framework},
  journal = {CoRR},
  volume  = {abs/2202.03052},
  year    = {2022}
}
```