可以任意指定每层的输出维度(但由于底层是通过调节nn.conv2d
的kernel_size实现的,所以当维度设置得不太合理时(比如在MCB
中间层试图扩大后两维的维度,会导致kernel_size为负),可能会导致程序报错,所以尽量参考原论文的维度设置,如下图)
Unsupervised-Rhythm-Clustering-Embedding的本质,是在上游模型提取特征的基础上,再进行聚类的训练
对于上游特征提取,有contrastive learning或者autoencoder这两种比较经典的无监督特征提取方法
对于下游聚类训练,可见原论文中参考的DEC,通过t分布来衡量嵌入点和聚类中心的相似度q_ij
= q
和目标相似度target
= $\frac{\frac{q^{2}{ij}}{f{j}}}{\sum_{j^{'}}\frac{q^{2}{ij^{'}}}{f{j}}}$计算kl散度作为损失函数($f_j = \sum_i q_{ij}$)。
调用官方simclr实现,传入一个encoder
即可,对于encoder
的限制很宽松,只需要尾接一个命名为self.fc的nn.linear层即可,在官方的simclr实现中会将这个层替换为nn.Identity,为了便于下游任务,self.fc的in_features和预先设定的feature_dim
保持一致
和传统对比学习不同的是,DEC不需要冻住上游模型,因为dec可训练的参数只有聚类中心。经试验发现,DEC需要的epoch并不多,否则聚类簇会倾向于聚类成线
class
ClusteringLearning
need a encoder
and clustering
数据可以在google drive下载,解压到data文件夹下即可
train.py
:
python train.py --task contrastive --max_epochs 100
在runs
文件夹下得到上游任务的checkpoint
python train.py --task clustering --max_epochs 20 --upstream_checkpoint_path your_contrastive_checkpoint_path
在runs
文件夹下得到下游任务的checkpoint
plotit.py
:
python plotit.py --upstream_checkpoint_path your_contrastive_checkpoint_path
在png
文件夹下得到show.png
可视化特征的聚类效果。plotit.py
仅支持num_class <= 10的分类别绘制,如要扩展颜色数,则需在visualization.py
的color
函数里增加颜色
model/music.py
:
对音乐数据(shape=[B,1,96,800])的encoder函数
module/learning.py
:
上下游任务的训练函数