Commit 9335732: update docs
airaria committed Nov 10, 2020
1 parent 0311f73
Showing 6 changed files with 178 additions and 32 deletions.

41 changes: 30 additions & 11 deletions README.md
@@ -88,6 +88,9 @@ Check our paper through [ACL Anthology](https://www.aclweb.org/anthology/2020.ac
<!-- /TOC -->

## Introduction

![](pics/arch.png)

**TextBrewer** is designed for knowledge distillation of NLP models. It provides various distillation methods and a framework for quickly setting up distillation experiments.

The main features of **TextBrewer** are:
@@ -122,6 +125,10 @@ To start distillation, users need to provide

See [Full Documentation](https://textbrewer.readthedocs.io/) for detailed usages.

### Architecture

![](pics/arch.png)

## Installation

* Requirements
@@ -150,6 +157,8 @@ See [Full Documentation](https://textbrewer.readthedocs.io/) for detailed usages

![](pics/distillation_workflow_en.png)

![](pics/distillation_workflow2.png)

* **Stage 1**: Preparation:
1. Train the teacher model
2. Define and initialize the student model
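For illustration, here is a minimal sketch of steps 1 and 2 using the `transformers` library; the checkpoint path, the model class, and the 3-layer student configuration are placeholders rather than files from this repository.

```python
# A minimal sketch of steps 1-2 with the transformers library. The checkpoint path and
# the 3-layer student configuration are placeholders, not files from this repository.
from transformers import BertConfig, BertForSequenceClassification

# Step 1: the teacher is a model you have already fine-tuned on the labeled task data.
teacher_model = BertForSequenceClassification.from_pretrained("path/to/finetuned-teacher")

# Step 2: define and initialize the student, e.g. a thinner BERT built from a modified config.
student_config = BertConfig.from_pretrained("bert-base-uncased", num_hidden_layers=3)
student_model = BertForSequenceClassification(student_config)  # randomly initialized
```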
@@ -186,7 +195,7 @@ print("student_model's parameters:")
result, _ = textbrewer.utils.display_parameters(student_model, max_level=3)
print(result)

# Define an adaptor for interpreting the model inputs and outputs
def simple_adaptor(batch, model_outputs):
# The second and third elements of model outputs are the logits and hidden states
return {'logits': model_outputs[1],
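The `'logits'` entry returned by the adaptor feeds the distiller's soft-label loss. As a rough illustration of what such a loss computes (not TextBrewer's exact implementation), a temperature-scaled KD loss over student and teacher logits could look like this:

```python
import torch.nn.functional as F

def soft_label_kd_loss(logits_S, logits_T, temperature=4.0):
    # Cross-entropy between the teacher's softened distribution and the student's
    # softened log-probabilities; a higher temperature produces softer targets.
    p_T = F.softmax(logits_T / temperature, dim=-1)
    log_p_S = F.log_softmax(logits_S / temperature, dim=-1)
    return -(p_T * log_p_S).sum(dim=-1).mean()
```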
@@ -218,6 +227,7 @@ with distiller:
* [examples/cmrc2018\_example](examples/cmrc2018_example) (Chinese): distillation on CMRC 2018, a Chinese MRC task, using DRCD as data augmentation.
* [examples/mnli\_example](examples/mnli_example) (English): distillation on MNLI, an English sentence-pair classification task. This example also shows how to perform multi-teacher distillation.
* [examples/conll2003_example](examples/conll2003_example) (English): distillation on the CoNLL-2003 English NER task, which is framed as sequence labeling.
* [examples/msra_ner_example](examples/msra_ner_example) (Chinese): distills a Chinese-ELECTRA-base model on the MSRA NER task with distributed data-parallel training (single node, multi-GPU).


## Experiments
@@ -291,22 +301,22 @@ We use GeneralDistiller in all the distillation experiments.
We experiment on the following typical English datasets:

| Dataset | Task type | Metrics | \#Train | \#Dev | Note |
| :---------- | -------- | ------- | ------- | ---- | ---- |
| [**MNLI**](https://www.nyu.edu/projects/bowman/multinli/) | text classification | m/mm Acc | 393K | 20K | sentence-pair 3-class classification |
| [**SQuAD 1.1**](https://rajpurkar.github.io/SQuAD-explorer/) | reading comprehension | EM/F1 | 88K | 11K | span-extraction machine reading comprehension |
| [**CoNLL-2003**](https://www.clips.uantwerpen.be/conll2003/ner) | sequence labeling | F1 | 23K | 6K | named entity recognition |

We list the public results from [DistilBERT](https://arxiv.org/abs/1910.01108), [BERT-PKD](https://arxiv.org/abs/1908.09355), [BERT-of-Theseus](https://arxiv.org/abs/2002.02925), [TinyBERT](https://arxiv.org/abs/1909.10351) and our results below for comparison.

Public results:

| Model (public) | MNLI | SQuAD | CoNLL-2003 |
| :------------- | --------------- | ------------- | --------------- |
| DistilBERT (T6) | 81.6 / 81.1 | 78.1 / 86.2 | - |
| BERT<sub>6</sub>-PKD (T6) | 81.5 / 81.0 | 77.1 / 85.3 | - |
| BERT-of-Theseus (T6) | 82.4 / 82.1 | - | - |
| BERT<sub>3</sub>-PKD (T3) | 76.7 / 76.3 | - | - |
| TinyBERT (T4-tiny) | 82.8 / 82.9 | 72.7 / 82.1 | - |

Our results:

@@ -377,7 +387,7 @@ Distillers are in charge of conducting the actual experiments. The following dis
* `BasicDistiller`: **single-teacher, single-task** distillation; provides basic distillation strategies.
* `GeneralDistiller` (recommended): **single-teacher, single-task** distillation with support for intermediate feature matching; recommended for most use cases (see the sketch below).
* `MultiTeacherDistiller`: **multi-teacher** distillation, which distills multiple teacher models (of the same task) into a single student model. **This class does not support intermediate feature matching.**
* `MultiTaskDistiller`: **multi-task** distillation, which distills multiple teacher models (of different tasks) into a single student.
* `BasicTrainer`: supervised training of a single model on a labeled dataset; not for distillation. **It can be used to train a teacher model.**
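For the recommended `GeneralDistiller`, a minimal configuration sketch is shown below. It reuses `teacher_model`, `student_model`, and `simple_adaptor` from the quick start above, and the layer indices, loss name, and weight in `intermediate_matches` are illustrative placeholders to be adapted to your models.

```python
from textbrewer import GeneralDistiller, TrainingConfig, DistillationConfig

# Match the teacher's 8th hidden layer to the student's 2nd, in addition to the
# soft-label loss on logits. Layer indices, loss name, and weight are placeholders.
distill_config = DistillationConfig(
    temperature=8,
    intermediate_matches=[
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}
    ])
train_config = TrainingConfig()

distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)
```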


@@ -406,9 +416,18 @@ We recommend that users use pre-trained student models whenever possible to full

**A**: Knowledge distillation usually requires more training epochs and a larger learning rate than supervised training on the labeled dataset. For example, fine-tuning BERT-base on SQuAD usually takes 3 epochs with lr=3e-5, whereas distillation takes 30~50 epochs with lr=1e-4. **These conclusions are based on our experiments; you are advised to tune the settings on your own data**.
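As a purely hypothetical illustration of that rule of thumb:

```python
from torch.optim import AdamW

# Hypothetical distillation settings following the advice above: a larger learning
# rate and many more epochs than plain fine-tuning (e.g. lr=3e-5 for 3 epochs) would use.
optimizer = AdamW(student_model.parameters(), lr=1e-4)
num_epochs = 40  # somewhere in the suggested 30~50 range
```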

**Q**: My teacher model and student model take different inputs (they do not share vocabularies), so how can I distill?

**A**: You need to feed different batches to the teacher and the student. See the section [Feed Different batches to Student and Teacher, Feed Cached Values](https://textbrewer.readthedocs.io/en/latest/Concepts.html#feed-different-batches-to-student-and-teacher-feed-cached-values) in the full documentation.
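A sketch of one way to build such paired batches with two tokenizers is given below; treat the `'teacher'`/`'student'` dict layout as an assumption inspired by that section's title, and check the linked documentation for the exact convention the distillers expect.

```python
def paired_collate(examples, teacher_tokenizer, student_tokenizer, max_length=128):
    # Tokenize the same raw texts twice, once per vocabulary, so the teacher and the
    # student each receive inputs they understand (labels omitted for brevity).
    texts = [ex["text"] for ex in examples]
    return {
        "teacher": teacher_tokenizer(texts, padding=True, truncation=True,
                                     max_length=max_length, return_tensors="pt"),
        "student": student_tokenizer(texts, padding=True, truncation=True,
                                     max_length=max_length, return_tensors="pt"),
    }
```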

**Q**: I have stored the logits from my teacher model. Can I reuse them during distillation to save the time of the teacher's forward pass?

**A**: Yes, see the section [Feed Different batches to Student and Teacher, Feed Cached Values](https://textbrewer.readthedocs.io/en/latest/Concepts.html#feed-different-batches-to-student-and-teacher-feed-cached-values) in the full documentation.
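For the caching step itself, a simple sketch is to run the teacher once under `torch.no_grad()` and store its logits; how the cached values are then fed to the distiller is covered in the linked section.

```python
import torch

@torch.no_grad()
def cache_teacher_logits(teacher_model, dataloader, device="cuda"):
    # Run the teacher once over the dataset and keep its logits on the CPU so that
    # later distillation runs can skip the teacher's forward pass.
    teacher_model.eval().to(device)
    cached = []
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = teacher_model(**batch)
        # Assumes a transformers-style ModelOutput; use outputs[0] for tuple outputs.
        cached.append(outputs.logits.cpu())
    return cached
```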

## Known Issues

* ~~Multi-GPU training support is only available through `DataParallel` currently.~~
* Multi-label classification is not supported.

## Citation

38 changes: 26 additions & 12 deletions README_ZH.md
@@ -25,8 +25,6 @@

### [TextBrewer Full Documentation](https://textbrewer.readthedocs.io/)

## The 2021 early-batch campus recruitment of the Joint Laboratory of HIT and iFLYTEK Research (HFL) has started! You are welcome to [submit your resume](https://wj.qq.com/s2/6730642/762d)

## News

**Aug 24, 2020**
@@ -123,6 +121,10 @@

See the [Full Documentation](https://textbrewer.readthedocs.io/) for detailed API references.

### TextBrewer Architecture

![](pics/arch.png)

## Installation

### Requirements
@@ -154,6 +156,8 @@ pip install ./textbrewer

![](pics/distillation_workflow.png)

![](pics/distillation_workflow2.png)

* **Stage 1**: Preparation before distillation:
1. Train the **teacher** model
2. Define and initialize the **student** model (randomly initialized, or initialized from pre-trained weights)
@@ -220,6 +224,7 @@ with distiller:
* [examples/cmrc2018\_example](examples/cmrc2018_example) (Chinese): distillation on CMRC 2018, a Chinese machine reading comprehension task, using the DRCD dataset for data augmentation.
* [examples/mnli\_example](examples/mnli_example) (English): distillation on MNLI, an English sentence-pair classification task; also shows how to use multi-teacher distillation.
* [examples/conll2003_example](examples/conll2003_example) (English): distillation on the CoNLL-2003 English NER task, framed as sequence labeling.
* [examples/msra_ner_example](examples/msra_ner_example) (Chinese): distillation of a Chinese-ELECTRA-base model on the MSRA NER (Chinese named entity recognition) task with distributed data-parallel training.

## Distillation Results

@@ -289,22 +294,22 @@ distill_config = DistillationConfig(temperature = 8, intermediate_matches = matc
In the English experiments, we use the following three typical datasets.

| Dataset | Task type | Metrics | \#Train | \#Dev | Note |
| :---------- | -------- | ------- | ------- | ---- | ---- |
| [**MNLI**](https://www.nyu.edu/projects/bowman/multinli/) | text classification | m/mm Acc | 393K | 20K | sentence-pair 3-class classification |
| [**SQuAD 1.1**](https://rajpurkar.github.io/SQuAD-explorer/) | reading comprehension | EM/F1 | 88K | 11K | span-extraction machine reading comprehension |
| [**CoNLL-2003**](https://www.clips.uantwerpen.be/conll2003/ner) | sequence labeling | F1 | 23K | 6K | named entity recognition |

In the two tables below, we list the public distillation results from [DistilBERT](https://arxiv.org/abs/1910.01108), [BERT-PKD](https://arxiv.org/abs/1908.09355), [BERT-of-Theseus](https://arxiv.org/abs/2002.02925), and [TinyBERT](https://arxiv.org/abs/1909.10351), and compare them with our results.

Public results:

| Model (public) | MNLI | SQuAD | CoNLL-2003 |
| :------------- | --------------- | ------------- | --------------- |
| DistilBERT (T6) | 81.6 / 81.1 | 78.1 / 86.2 | - |
| BERT<sub>6</sub>-PKD (T6) | 81.5 / 81.0 | 77.1 / 85.3 | - |
| BERT-of-Theseus (T6) | 82.4 / 82.1 | - | - |
| BERT<sub>3</sub>-PKD (T3) | 76.7 / 76.3 | - | - |
| TinyBERT (T4-tiny) | 82.8 / 82.9 | 72.7 / 82.1 | - |

Our results:

@@ -371,7 +376,7 @@ The Distiller carries out the actual distillation process. The following distillers are currently implemented:
* `BasicDistiller`: **single-teacher, single-task** distillation; useful for testing or simple experiments.
* `GeneralDistiller` (commonly used): **single-teacher, single-task** distillation with support for **intermediate feature matching**; **recommended** in most cases.
* `MultiTeacherDistiller`: multi-teacher distillation, which distills multiple teacher models (of the same task) into a single student model. **Intermediate feature matching is not supported yet.**
* `MultiTaskDistiller`: multi-task distillation, which distills multiple single-task teacher models (of different tasks) into one multi-task student model.
* `BasicTrainer`: supervised training of a single model, not distillation. **It can be used to train a teacher model** (see the sketch below).
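As mentioned above, `BasicTrainer` can be used to train the teacher. A minimal sketch follows; the `'losses'` adaptor key mirrors the adaptor concept from the quick start, and the constructor argument order is an assumption to verify against the API documentation.

```python
from textbrewer import BasicTrainer, TrainingConfig

# Supervised-training adaptor: expose the task loss under the 'losses' key
# (assumes the model returns its loss as the first output; illustrative only).
def teacher_adaptor(batch, model_outputs):
    return {'losses': (model_outputs[0],)}

# Assumed argument order: (train_config, model, adaptor).
trainer = BasicTrainer(TrainingConfig(), teacher_model, teacher_adaptor)
```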

### User-Defined Functions
@@ -398,9 +403,18 @@

**A**: Knowledge distillation requires more training epochs and a larger learning rate than training on labeled data. For example, fine-tuning BERT-base on SQuAD usually reaches good performance with about 3 epochs at lr=3e-5, while distillation needs 30~50 epochs at lr=1e-4. The specifics certainly differ from task to task; **our suggestions are based only on our own experience and are for reference only**.

**Q**: My teacher model and student model take different inputs (e.g., incompatible input_ids due to different vocabularies). How can I distill?

**A**: You need to feed separate batches to the teacher and the student. See the section [Feed Different batches to Student and Teacher, Feed Cached Values](https://textbrewer.readthedocs.io/en/latest/Concepts.html#feed-different-batches-to-student-and-teacher-feed-cached-values) in the full documentation.

**Q**: I have cached the outputs of my teacher model. Can they be used to speed up distillation?

**A**: Yes. See the section [Feed Different batches to Student and Teacher, Feed Cached Values](https://textbrewer.readthedocs.io/en/latest/Concepts.html#feed-different-batches-to-student-and-teacher-feed-cached-values) in the full documentation.

## Known Issues

* ~~Multi-GPU training strategies other than `DataParallel` are not supported yet.~~
* Multi-label classification is not supported yet.

## Citation
