使用Swin-Transformer + Query2Label的方案解决Plant Pathology 2021 - FGVC8的叶片分类问题,该方案模型可以被简称为QST。
实现路径以及代码详细解释可见博客cleversmall。
由于独热编码对数据集进行了改动,故需要对仓库中的plant_dataset
文件夹和下载所得到的plant_dataset
文件夹进行合并,合并后的格式应该如下。
├─train
| ├─images
| ├─train_label.csv
| ├─labels.csv
├─val
| ├─images
| ├─val_label.csv
| ├─labels.csv
├─test
| ├─images
| ├─test_label.csv
| ├─labels.csv
本项目所要求的python环境并无特殊,你可以简单的在自己的torch环境中进行配置也可。
pip install -r requirements.txt
python train.py --data-path <path of plant_dataset> --weights <path of your weight>
其中weights不进行设置则说明从随机参数开始训练。
python predict.py --img-path <path of plant_dataset> --weights <path of your weight>
其在终端输出的结果例子可以是
tensor([[ 2.2923, -0.8050, -0.4738, -0.5310, -0.2975, -1.1503]])
tensor([[0.9082, 0.3090, 0.3837, 0.3703, 0.4262, 0.2404]])
两行的意思一致,分别代表了属于不同标签分类的原始概率和经过sigmoid函数后的概率,其中类别1-6分别对应的是
scab | healthy | frog_eye_leaf_spot | rust | complex | powdery_mildew |
---|
则说明该叶片最可能属于scab。
在进行输入前,首先需要将12类的混合分类简化为6类,这会导致多标签的情况,故使用独热码对标签进行分类。
模型即使用最基本的SwinTransformer并在后端加入Query2Label以优化多标签分类结果。
PR曲线
F1分数曲线