Simple code to visualize the attention values of a Transformer-based language model.
The main idea of handling attention values comes from the codebase of the ACL-IJCNLP paper "LEWIS: Levenshtein Editing for Unsupervised Text Style Transfer". According to the paper, the penultimate (second-to-last) layer works best (e.g., the 11th layer for RoBERTa-base).
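As a rough sketch of how such per-layer attention values can be extracted with the Hugging Face `transformers` API (the checkpoint and variable names here are illustrative, not this repository's exact code):

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Any encoder checkpoint works; roberta-base is just an example.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModel.from_pretrained("roberta-base", output_attentions=True)

inputs = tokenizer("a sentence to inspect", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is a tuple with one tensor per layer, each of shape
# (batch, num_heads, seq_len, seq_len); [-2] selects the penultimate layer.
penultimate = outputs.attentions[-2]
# Average over heads to get a single (seq_len, seq_len) attention map.
attn_map = penultimate.mean(dim=1)[0]
```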
Setup:
```
conda env create -n <name> -f requirements.txt
```
- If you want GPU-enabled torch:
```
conda activate <name>
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
```
(see https://pytorch.org/get-started/locally/ for the command matching your CUDA version)
- Or simply check whether the packages in requirements.txt are already installed in your environment.
To run:
```
conda activate <name>
python viz_attention.py
```
- Or open demo.ipynb and run it for a demo (a rough sketch of the demo follows the list below).
- Data: SST-2 test set
- Model: distilbert-base-uncased-finetuned-sst-2-english
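As an illustration of the demo setup, a minimal sketch that loads the model above and plots a head-averaged, penultimate-layer attention map for one sentence; the input sentence and plotting details are assumptions, not necessarily what viz_attention.py or demo.ipynb do exactly:

```python
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name, output_attentions=True)

# An SST-2-style movie-review sentence (illustrative input).
inputs = tokenizer("a gorgeous, witty, seductive movie", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# DistilBERT has 6 layers, so attentions[-2] is the penultimate (5th) layer.
attn = outputs.attentions[-2].mean(dim=1)[0]  # head-averaged (seq, seq) map
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())

plt.imshow(attn.numpy(), cmap="viridis")
plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.yticks(range(len(tokens)), tokens)
plt.colorbar()
plt.tight_layout()
plt.show()
```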