This repository contains the code for our EMNLP 2022 paper *JANUS: Joint Autoregressive and Non-autoregressive Training with Auxiliary Loss for Sequence Generation*.
JANUS is a new training strategy that enhances model performance in both the autoregressive (AR) and non-autoregressive (NAR) settings simultaneously, and effectively alleviates the problem of distribution discrepancy.
*(Figure: the problem of distribution discrepancy)*

*(Figure: JANUS)*
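To make the idea concrete, here is a toy sketch (not the paper's implementation) of how a joint objective can combine an AR loss, a NAR loss, and an auxiliary term. The function names, the weighting scheme, and the coefficients `alpha` and `beta` are illustrative assumptions, not the paper's exact formulation.

```python
# Toy sketch of a JANUS-style joint objective (illustrative only):
# a weighted sum of an AR loss, a NAR loss, and an auxiliary loss.
import math

def cross_entropy(probs, target_idx):
    """Negative log-likelihood of the target index under `probs`."""
    return -math.log(probs[target_idx])

def joint_loss(ar_probs, nar_probs, targets, aux_loss=0.0, alpha=0.5, beta=0.1):
    """Illustrative combination: alpha * AR + (1 - alpha) * NAR + beta * aux."""
    ar = sum(cross_entropy(p, t) for p, t in zip(ar_probs, targets)) / len(targets)
    nar = sum(cross_entropy(p, t) for p, t in zip(nar_probs, targets)) / len(targets)
    return alpha * ar + (1 - alpha) * nar + beta * aux_loss

# Two-token toy example: per-position distributions over a 3-word vocabulary.
ar_probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]]   # teacher-forced AR predictions
nar_probs = [[0.5, 0.3, 0.2], [0.2, 0.6, 0.2]]  # parallel NAR predictions
targets = [0, 1]
print(round(joint_loss(ar_probs, nar_probs, targets), 4))
```

In practice both losses are computed from a single shared model, which is what lets JANUS serve AR and NAR inference from one set of weights.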
Set up the environment with conda and pip, choosing a PyTorch version suitable for your platform:
```bash
# For example:
conda create -n janus python==3.8
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
```
Install this project:
```bash
pip install --editable .
```
Download and process the data:
- IWSLT14 De-En: We use the fairseq script `prepare-iwslt14.sh` to download and preprocess the dataset, following the fairseq instructions.
- WMT16: We use the processed data released by the DisCo authors; thanks to them and their project.
- GLGE: We use the GLGE benchmark for general NLG tasks.
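After downloading and tokenizing, the data is typically binarized with fairseq before training. The invocation below is an illustrative sketch for IWSLT14 De-En; the paths, dictionary option, and worker count are assumptions, so check the scripts in this repo for the exact flags used.

```bash
# Illustrative fairseq binarization for IWSLT14 De-En (paths are placeholders).
fairseq-preprocess \
    --source-lang de --target-lang en \
    --trainpref iwslt14.tokenized.de-en/train \
    --validpref iwslt14.tokenized.de-en/valid \
    --testpref iwslt14.tokenized.de-en/test \
    --destdir data-bin/iwslt14.de-en \
    --joined-dictionary --workers 8
```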
Model Training with JANUS:
```bash
bash janus/scripts/iwslt/run_janus.sh
```
Model Inference with JANUS:
```bash
bash janus/scripts/iwslt/run_janus_nar_inf.sh
bash janus/scripts/iwslt/run_janus_ar_inf.sh
```
Results:

| IWSLT14 | De->En (AR) | De->En (NAR) |
|---------|-------------|--------------|
| JANUS   | 37.24       | 34.21        |